diff --git a/AGENTS.md b/AGENTS.md index f6356421ea..83b20fd344 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -288,10 +288,18 @@ func NewProvider(config schemas.ProviderConfig) (*Provider, error) { MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost, // configurable, default 5000 MaxIdleConnDuration: 30 * time.Second, } - return &Provider{client: client, ...}, nil + // After ConfigureProxy/ConfigureDialer/ConfigureTLS, build a sibling client + // for streaming. BuildStreamingClient zeros ReadTimeout/WriteTimeout/MaxConnDuration + // so streams aren't killed by fasthttp's whole-response deadline; per-chunk idle + // is enforced at the app layer via NewIdleTimeoutReader. + streamingClient := providerUtils.BuildStreamingClient(client) + return &Provider{client: client, streamingClient: streamingClient, ...}, nil } ``` -**Note:** Bedrock uses `net/http` (not fasthttp) with HTTP/2 support. Its `http.Transport` is configured with `ForceAttemptHTTP2: true` and `MaxConnsPerHost` from `NetworkConfig` to allow multiple HTTP/2 connections when the server's per-connection stream limit (100 for AWS Bedrock) is reached. + +**Streaming vs unary client:** Every provider holds two clients — `client` for unary requests (`ReadTimeout=30s` bounds the whole response) and `streamingClient` for SSE / EventStream / chunked paths (`ReadTimeout=0`; the per-chunk `NewIdleTimeoutReader` is the only governor). Pass `provider.streamingClient` to every `Handle*Streaming` / `Handle*StreamRequest` helper and to direct `Do` calls inside `*Stream` methods. For new providers, apply the same pattern — missing the switch means streams get killed at 30s. + +**Note:** Bedrock uses `net/http` (not fasthttp) with HTTP/2 support. Its `http.Transport` is configured with `ForceAttemptHTTP2: true` and `MaxConnsPerHost` from `NetworkConfig` to allow multiple HTTP/2 connections when the server's per-connection stream limit (100 for AWS Bedrock) is reached. 
Use `providerUtils.BuildStreamingHTTPClient(client)` to derive the streaming variant — it shares the base `Transport` (safe for concurrent reuse) but clears `Client.Timeout`. ### The Provider Interface @@ -509,6 +517,21 @@ In `tests/e2e/core/`, **never marshal API payloads to a `Record`/`Map`/plain-obj ## Testing +### Always prefer `make test-core` over raw `go test` for provider-level tests + +The `make test-core` target is the canonical harness for provider tests — it wires up env vars from `.env` (provider API keys), invokes the per-provider `{provider}_test.go` entrypoint in `core/providers/{provider}/`, and routes through the shared `core/internal/llmtests/` scenario suite that validates end-to-end behavior (including streaming). + +Running bare `go test ./core/providers/{provider}/...` only executes unit tests and skips the llmtests scenarios — so it won't catch regressions in streaming, tool-calling, or provider-specific response shapes. + +```bash +make test-core PROVIDER=anthropic TESTCASE=TestChatCompletionStream # exact test +make test-core PROVIDER=openai PATTERN=Stream # substring match +make test-core PROVIDER=bedrock # all scenarios for one provider +make test-core DEBUG=1 PROVIDER=gemini TESTCASE=TestResponsesStream # attach Delve on :2345 +``` + +`PATTERN` and `TESTCASE` are mutually exclusive. Provider name must match a directory under `core/providers/` (e.g. `anthropic`, `openai`, `bedrock`, `vertex`, `azure`, `gemini`, `cohere`, `mistral`, `groq`, etc.). + ### LLM Tests (`core/internal/llmtests/`) Scenario-based tests that run against **live provider APIs** with dual-API testing (Chat Completions + Responses API): @@ -648,6 +671,7 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso - **Converter functions**: Pure — no side effects, no logging, no HTTP. - **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`). - **Context keys**: Use `BifrostContextKey` type. 
Custom plugins should define their own key types to avoid collisions. +- **Go filenames**: No underscores. The only permitted underscore is the `_test.go` suffix. Examples: `pluginpipeline.go`, `pluginpipeline_test.go` — never `plugin_pipeline.go` or `plugin_pipeline_race_test.go`. Concatenate words (lowercase, no separators) for multi-word filenames. # Frontend Code Guidelines & Patterns diff --git a/core/bifrost.go b/core/bifrost.go index 83ec4777af..cf2a095a67 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -171,7 +171,12 @@ type PluginPipeline struct { preHookErrors []error postHookErrors []error - // Streaming post-hook timing accumulation (for aggregated spans) + // streamingMu guards the streaming post-hook accumulators below. Per-chunk + // writes (accumulatePluginTiming) run in the provider goroutine while the + // end-of-stream finalizer (FinalizeStreamingPostHookSpans) and + // resetPluginPipeline can run in a different goroutine, so unsynchronised + // access triggers "concurrent map read and map write" panics. + streamingMu sync.Mutex postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name postHookPluginOrder []string // order in which post-hooks ran (for nested span creation) chunkCount int @@ -5563,59 +5568,78 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas originalModelRequested := model // resolvedModel is set inside the handler closures below on every attempt so that each - // key's own alias mapping is applied. postHookRunner captures resolvedModel by reference - // (Go closure semantics) and will therefore always see the value from the last attempt. + // key's own alias mapping is applied. The outer var holds the LAST attempt's value and is + // read single-threaded by the worker after retries finish (e.g. the error-fallback at + // line 5653). 
Streaming postHookRunner must NOT capture this var by reference — it + // snapshots its own attemptResolvedModel inside the per-attempt closure. var resolvedModel string - - // Create plugin pipeline for streaming requests outside retry loop to prevent leaks - var postHookRunner schemas.PostHookRunner - var pipeline *PluginPipeline - if IsStreamRequestType(req.RequestType) { - pipeline = bifrost.getPluginPipeline() - postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) - // can read requestType/provider/model from the chunk or error. - // resolvedModel is captured by reference and reflects the alias from the last attempt. - if result != nil { - result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) - } - if err != nil { - err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) - } - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) - if IsFinalChunk(ctx) { - drainAndAttachPluginLogs(ctx) - } - if bifrostErr != nil { - bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) - return nil, bifrostErr - } else if resp != nil { - resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) - } - return resp, nil - } - // Store a finalizer callback to create aggregated post-hook spans at stream end. - // Wrapped in sync.Once so the normal end-of-stream invocation and a deferred - // safety-net invocation (e.g. from a provider goroutine's panic path) cannot - // double-release the pipeline. 
- var finalizerOnce sync.Once - postHookSpanFinalizer := func(ctx context.Context) { - finalizerOnce.Do(func() { - pipeline.FinalizeStreamingPostHookSpans(ctx) - bifrost.releasePluginPipeline(pipeline) - }) - } - req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer) - } - - // Execute request with retries. Each handler invocation resolves the alias for the key - // selected by keyProvider on that attempt and mutates the worker-local request model. - // resolvedModel (captured by reference in postHookRunner) is updated accordingly. + // lastAttemptFinalizer captures the LAST attempt's postHookSpanFinalizer for the + // worker-level error fallback below. Single-threaded write (assigned by the retry + // loop's per-attempt closure) and single-threaded read (after retries finish), so + // no synchronization needed. Earlier attempts' finalizers fire via their provider + // goroutines' defers — passed via the postHookSpanFinalizer parameter directly to + // handleProviderStreamRequest, never via the shared req.Context. + var lastAttemptFinalizer func(context.Context) + + // Execute request with retries. For streaming, the plugin pipeline, + // postHookRunner, and finalizer are allocated per-attempt inside the + // request handler closure. If they were request-scoped, a retry + // triggered by CheckFirstStreamChunkForError could run against a + // pipeline the previous attempt's provider goroutine has already + // returned to the pool via its deferred finalizer. 
if IsStreamRequestType(req.RequestType) { stream, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { resolvedModel = k.Aliases.Resolve(originalModelRequested) req.SetModel(resolvedModel) - return bifrost.handleProviderStreamRequest(provider, req, k, postHookRunner) + // Snapshot per-attempt so postHookRunner doesn't observe a later retry's + // alias while this attempt's provider goroutine is still emitting chunks. + attemptResolvedModel := resolvedModel + pipeline := bifrost.getPluginPipeline() + postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + // Uses the per-attempt snapshot — capturing the outer resolvedModel by + // reference would let a later retry's alias bleed into this attempt's chunks. + if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + } + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + return nil, bifrostErr + } else if resp != nil { + resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + } + return resp, nil + } + // Store a finalizer callback to create aggregated post-hook spans at stream end. 
+ // Wrapped in sync.Once so the normal end-of-stream invocation and a deferred + // safety-net invocation (e.g. from a provider goroutine's panic path) cannot + // double-release the pipeline. + var finalizerOnce sync.Once + postHookSpanFinalizer := func(ctx context.Context) { + finalizerOnce.Do(func() { + pipeline.FinalizeStreamingPostHookSpans(ctx) + bifrost.releasePluginPipeline(pipeline) + }) + } + lastAttemptFinalizer = postHookSpanFinalizer + streamCh, streamErr := bifrost.handleProviderStreamRequest(provider, req, k, postHookRunner, postHookSpanFinalizer) + // If stream setup failed before any provider goroutine started, + // no deferred finalizer will run — release the pipeline directly + // so a retry doesn't inherit a leaked pool entry. + if streamErr != nil && streamCh == nil { + finalizerOnce.Do(func() { + bifrost.releasePluginPipeline(pipeline) + }) + } + return streamCh, streamErr }, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) } else { result, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { @@ -5625,11 +5649,20 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas }, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) } - // Release pipeline immediately for non-streaming requests only - // For streaming, the pipeline is released in the postHookSpanFinalizer after streaming completes - // Exception: if streaming request has an error, release immediately since finalizer won't be called - if pipeline != nil && (!IsStreamRequestType(req.RequestType) || bifrostError != nil) { - bifrost.releasePluginPipeline(pipeline) + // For streaming with an error, route release through the LAST attempt's + // finalizer (wrapped in sync.Once) so we don't double-Put into the pool + // or race the provider goroutine's deferred 
FinalizeStreamingPostHookSpans + // call. lastAttemptFinalizer is set inside the per-attempt closure on every + // iteration; after retries finish, it holds the LAST attempt's finalizer. + // Earlier attempts' finalizers have already fired via their provider + // goroutines' defers (passed via the postHookSpanFinalizer parameter + // directly to handleProviderStreamRequest). For streaming without error, + // the finalizer is invoked by completeDeferredSpan / the provider + // goroutine's defer. + if IsStreamRequestType(req.RequestType) && bifrostError != nil { + if lastAttemptFinalizer != nil { + lastAttemptFinalizer(req.Context) + } } if bifrostError != nil { @@ -5985,36 +6018,36 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config } // handleProviderStreamRequest handles the stream request to the provider based on the request type -func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context)) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { switch req.RequestType { case schemas.TextCompletionStreamRequest: if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() if chatRequest != nil { - return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest), key, chatRequest) + return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest), postHookSpanFinalizer, key, 
chatRequest) } } - return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest) + return provider.TextCompletionStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionStreamRequest: if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() if responsesRequest != nil { - return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), key, responsesRequest) + return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), postHookSpanFinalizer, key, responsesRequest) } } - return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest) + return provider.ChatCompletionStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ChatRequest) case schemas.ResponsesStreamRequest: - return provider.ResponsesStream(req.Context, postHookRunner, key, req.BifrostRequest.ResponsesRequest) + return provider.ResponsesStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ResponsesRequest) case schemas.SpeechStreamRequest: - return provider.SpeechStream(req.Context, postHookRunner, key, req.BifrostRequest.SpeechRequest) + return provider.SpeechStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.SpeechRequest) case schemas.TranscriptionStreamRequest: - return provider.TranscriptionStream(req.Context, postHookRunner, key, req.BifrostRequest.TranscriptionRequest) + return provider.TranscriptionStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.TranscriptionRequest) case schemas.ImageGenerationStreamRequest: - return 
provider.ImageGenerationStream(req.Context, postHookRunner, key, req.BifrostRequest.ImageGenerationRequest) + return provider.ImageGenerationStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ImageGenerationRequest) case schemas.ImageEditStreamRequest: - return provider.ImageEditStream(req.Context, postHookRunner, key, req.BifrostRequest.ImageEditRequest) + return provider.ImageEditStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ImageEditRequest) case schemas.PassthroughStreamRequest: - return provider.PassthroughStream(req.Context, postHookRunner, key, req.BifrostRequest.PassthroughRequest) + return provider.PassthroughStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.PassthroughRequest) default: _, model, _ := req.BifrostRequest.GetRequestFields() return nil, &schemas.BifrostError{ @@ -6316,7 +6349,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche } // Increment chunk count for streaming if isStreaming { + p.streamingMu.Lock() p.chunkCount++ + p.streamingMu.Unlock() } // Final logic: if both are set, error takes precedence, unless error is nil if bifrostErr != nil { @@ -6445,20 +6480,39 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s // which nils out their pluginLogs pointer. The drain reads from the shared store // on the root context, so it must happen while the store is still referenced. func (p *PluginPipeline) resetPluginPipeline() { + // Drop cross-request references while the object sits in the pool. + // getPluginPipeline rebinds all four on acquisition, so nil'ing here + // only affects GC hygiene — important when plugins are hot-swapped. 
+ p.llmPlugins = nil + p.mcpPlugins = nil p.executedPreHooks = 0 + clear(p.preHookErrors) p.preHookErrors = p.preHookErrors[:0] + clear(p.postHookErrors) p.postHookErrors = p.postHookErrors[:0] - // Reset streaming timing accumulation + // Reset streaming timing accumulation under lock — the provider goroutine's + // deferred finalizer may still be iterating these fields when the pipeline + // is returned to the pool. logger/tracer are nilled here too so the write + // is synchronized with the finalizer's read under the same mutex. + p.streamingMu.Lock() + p.logger = nil + p.tracer = nil p.chunkCount = 0 if p.postHookTimings != nil { + // clear() drops *pluginTimingAccumulator values (freeing them for GC) + // while retaining the map's backing hash table for reuse. clear(p.postHookTimings) } + // clear() zeros elements in [0, len) — scrub before [:0] so the backing + // array doesn't retain live string references once the slice is truncated. + clear(p.postHookPluginOrder) p.postHookPluginOrder = p.postHookPluginOrder[:0] // Release cached scoped contexts for streaming for _, scopedCtx := range p.streamScopedCtxs { scopedCtx.ReleasePluginScope() } p.streamScopedCtxs = nil + p.streamingMu.Unlock() } // drainAndAttachPluginLogs drains accumulated plugin logs from the BifrostContext @@ -6477,6 +6531,8 @@ func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) { // accumulatePluginTiming accumulates timing for a plugin during streaming func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) { + p.streamingMu.Lock() + defer p.streamingMu.Unlock() if p.postHookTimings == nil { p.postHookTimings = make(map[string]*pluginTimingAccumulator) } @@ -6498,7 +6554,40 @@ func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time // This should be called once at the end of streaming to create one span per plugin with average timing. 
// Spans are nested to mirror the pre-hook hierarchy (each post-hook is a child of the previous one). func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) { - if p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 { + // Snapshot the accumulators under lock so per-chunk writers in the + // provider goroutine can't race with the finalizer. Tracer calls below + // run unlocked — we don't want to stall chunk writers on span I/O. + type snapshotEntry struct { + pluginName string + totalDuration time.Duration + invocations int + errors int + } + p.streamingMu.Lock() + // Capture tracer under the same lock that guards resetPluginPipeline's + // writes so the read/write pair on p.tracer is synchronized and the + // unlocked tracer calls below use a stable local. + tracer := p.tracer + if tracer == nil || p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 { + p.streamingMu.Unlock() + return + } + snapshot := make([]snapshotEntry, 0, len(p.postHookPluginOrder)) + for _, pluginName := range p.postHookPluginOrder { + timing, ok := p.postHookTimings[pluginName] + if !ok || timing.invocations == 0 { + continue + } + snapshot = append(snapshot, snapshotEntry{ + pluginName: pluginName, + totalDuration: timing.totalDuration, + invocations: timing.invocations, + errors: timing.errors, + }) + } + p.streamingMu.Unlock() + + if len(snapshot) == 0 { return } @@ -6507,50 +6596,47 @@ func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) { handle schemas.SpanHandle hasErrors bool } - spans := make([]spanInfo, 0, len(p.postHookPluginOrder)) + spans := make([]spanInfo, 0, len(snapshot)) currentCtx := ctx // Start spans in execution order (nested: each is a child of the previous) - for _, pluginName := range p.postHookPluginOrder { - timing, ok := p.postHookTimings[pluginName] - if !ok || timing.invocations == 0 { - continue - } - + for _, entry := range snapshot { // Create span as child of the previous span (nested hierarchy) 
- newCtx, handle := p.tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + newCtx, handle := tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(entry.pluginName)), schemas.SpanKindPlugin) if handle == nil { continue } // Calculate average duration in milliseconds - avgMs := float64(timing.totalDuration.Milliseconds()) / float64(timing.invocations) + avgMs := float64(entry.totalDuration.Milliseconds()) / float64(entry.invocations) // Set aggregated attributes - p.tracer.SetAttribute(handle, schemas.AttrPluginInvocations, timing.invocations) - p.tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs) - p.tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, timing.totalDuration.Milliseconds()) + tracer.SetAttribute(handle, schemas.AttrPluginInvocations, entry.invocations) + tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs) + tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, entry.totalDuration.Milliseconds()) - if timing.errors > 0 { - p.tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, timing.errors) + if entry.errors > 0 { + tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, entry.errors) } - spans = append(spans, spanInfo{handle: handle, hasErrors: timing.errors > 0}) + spans = append(spans, spanInfo{handle: handle, hasErrors: entry.errors > 0}) currentCtx = newCtx } // End spans in reverse order (innermost first, like unwinding a call stack) for i := len(spans) - 1; i >= 0; i-- { if spans[i].hasErrors { - p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed") + tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed") } else { - p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "") + tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "") } } } // GetChunkCount returns the number of chunks processed during streaming func (p 
*PluginPipeline) GetChunkCount() int { + p.streamingMu.Lock() + defer p.streamingMu.Unlock() return p.chunkCount } diff --git a/core/go.sum b/core/go.sum index 685035b381..5869ae64b8 100644 --- a/core/go.sum +++ b/core/go.sum @@ -33,13 +33,17 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22/go.mod h1:zd/JsJ4P7oGfUhXn1VyLqaRZwPmZwg44Jf2dS84Dm3Y= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 h1:JRaIgADQS/U6uXDqlPiefP32yXTda7Kqfx+LgspooZM= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13/go.mod h1:CEuVn5WqOMilYl+tbccq8+N2ieCy0gVn3OtRb0vBNNM= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 h1:ZlvrNcHSFFWURB8avufQq9gFsheUgjVD9536obIknfM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21/go.mod h1:cv3TNhVrssKR0O/xxLJVRfd2oazSnZnkUeTf6ctUwfQ= github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 h1:HwxWTbTrIHm5qY+CAEur0s/figc3qwvLWsNkF4RPToo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3/go.mod h1:uoA43SdFwacedBfSgfFSjjCvYe8aYBS7EnU5GZ/YKMM= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= 
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= diff --git a/core/internal/llmtests/chat_completion_stream.go b/core/internal/llmtests/chat_completion_stream.go index 0887da7e0b..e8ae70435f 100644 --- a/core/internal/llmtests/chat_completion_stream.go +++ b/core/internal/llmtests/chat_completion_stream.go @@ -164,6 +164,15 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont t.Logf("⚠️ Warning: Response ID is empty") } + // Per-chunk Object validation: bifrost normalizes every streaming chunk + // to the OpenAI shape with Object="chat.completion.chunk", whether the + // upstream provider natively emits it (OpenAI family) or bifrost + // synthesizes it during translation (e.g., Anthropic's type-keyed events). + // A missing/wrong Object here indicates a provider translation regression. + if response.BifrostChatResponse.Object != "chat.completion.chunk" { + t.Errorf("Chunk %d: Object field must be 'chat.completion.chunk', got %q", responseCount+1, response.BifrostChatResponse.Object) + } + // Log latency for each chunk (can be 0 for inter-chunks) t.Logf("📊 Chunk %d latency: %d ms", responseCount+1, response.BifrostChatResponse.ExtraFields.Latency) diff --git a/core/internal/llmtests/response_validation.go b/core/internal/llmtests/response_validation.go index bc75dd07df..788436b994 100644 --- a/core/internal/llmtests/response_validation.go +++ b/core/internal/llmtests/response_validation.go @@ -94,7 +94,7 @@ func ValidateChatResponse(t *testing.T, response *schemas.BifrostChatResponse, e } // Validate basic structure - validateChatBasicStructure(t, response, expectations, &result) + validateChatBasicStructure(t, response, expectations, &result, scenarioName) // Validate content validateChatContent(t, response, expectations, &result) @@ -445,11 +445,17 @@ func 
ValidateCountTokensResponse(t *testing.T, response *schemas.BifrostCountTok // ============================================================================= // validateChatBasicStructure checks the basic structure of the chat response -func validateChatBasicStructure(t *testing.T, response *schemas.BifrostChatResponse, expectations ResponseExpectations, result *ValidationResult) { - // Check that Object field is not empty (should be "chat.completion" or "chat.completion.chunk") - if response.Object == "" { - result.Passed = false - result.Errors = append(result.Errors, "Object field is empty in chat completion response") +func validateChatBasicStructure(t *testing.T, response *schemas.BifrostChatResponse, expectations ResponseExpectations, result *ValidationResult, scenarioName string) { + // Object is a constant bifrost schema marker ("chat.completion" / "chat.completion.chunk"). + // For streaming scenarios, per-chunk validation in chat_completion_stream.go covers this — + // the aggregated/consolidated response built by the harness is a synthetic construct and + // does not carry provider-originating semantics. Skip the check there to avoid asserting + // that the harness remembered to copy a constant forward. 
+ if !strings.Contains(scenarioName, "Stream") { + if response.Object == "" { + result.Passed = false + result.Errors = append(result.Errors, "Object field is empty in chat completion response") + } } // Check choice count diff --git a/core/pluginpipelinerace_test.go b/core/pluginpipelinerace_test.go new file mode 100644 index 0000000000..4d5abe58e0 --- /dev/null +++ b/core/pluginpipelinerace_test.go @@ -0,0 +1,77 @@ +package bifrost + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +// TestPluginPipelineStreamingRace reproduces the production panic: +// +// fatal error: concurrent map read and map write +// (*PluginPipeline).FinalizeStreamingPostHookSpans +// +// It hammers accumulatePluginTiming (per-chunk writer) concurrently with +// FinalizeStreamingPostHookSpans (end-of-stream reader) and resetPluginPipeline +// (pool-release writer). Before the streamingMu fix these three paths had no +// synchronisation and the -race detector / runtime map check would trip +// immediately. Run with: go test -race -run PluginPipelineStreamingRace +func TestPluginPipelineStreamingRace(t *testing.T) { + p := &PluginPipeline{ + logger: NewDefaultLogger(schemas.LogLevelError), + tracer: &schemas.NoOpTracer{}, + } + + const writers = 8 + const iterations = 2000 + + var wg sync.WaitGroup + + // Per-chunk accumulator writers — simulate multiple plugins accumulating + // timing for every streamed chunk. + for w := 0; w < writers; w++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + pluginName := fmt.Sprintf("plugin-%d", id%3) // a few distinct plugin keys + for i := 0; i < iterations; i++ { + p.accumulatePluginTiming(pluginName, time.Microsecond, i%17 == 0) + } + }(w) + } + + // End-of-stream finalizer racing with writers. 
+ wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for i := 0; i < iterations/10; i++ { + p.FinalizeStreamingPostHookSpans(ctx) + } + }() + + // resetPluginPipeline racing with writers — simulates the pool returning + // the pipeline to another request mid-flight. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations/10; i++ { + p.resetPluginPipeline() + } + }() + + // Concurrent GetChunkCount readers. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _ = p.GetChunkCount() + } + }() + + wg.Wait() +} diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 0fc6073ced..1d011ee0f2 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -24,7 +24,8 @@ import ( // AnthropicProvider implements the Provider interface for Anthropic's Claude API. type AnthropicProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) apiVersion string // API version for the provider networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse @@ -101,6 +102,7 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.anthropic.com" @@ 
-110,6 +112,7 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) return &AnthropicProvider{ logger: logger, client: client, + streamingClient: streamingClient, apiVersion: "2023-06-01", networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, @@ -408,7 +411,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k // TextCompletionStream performs a streaming text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *AnthropicProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -508,7 +511,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k // ChatCompletionStream performs a streaming chat completion request to the Anthropic API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -566,7 +569,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont // Use shared Anthropic streaming logic return HandleAnthropicChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionStreamRequest), jsonData, headers, @@ -578,6 +581,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -597,6 +601,7 @@ func HandleAnthropicChatCompletionStreaming( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -669,12 +674,12 @@ func HandleAnthropicChatCompletionStreaming( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -686,7 +691,7 @@ func HandleAnthropicChatCompletionStreaming( fmt.Errorf("provider returned an empty response"), ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -737,7 +742,7 @@ func HandleAnthropicChatCompletionStreaming( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) return } break @@ -855,7 +860,7 @@ func HandleAnthropicChatCompletionStreaming( response.ExtraFields.RawResponse = eventData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) continue } } @@ -870,7 +875,7 @@ func HandleAnthropicChatCompletionStreaming( response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream(ctx, 
structuredOutputToolName, streamState) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger, postHookSpanFinalizer) break } if response != nil { @@ -893,7 +898,7 @@ func HandleAnthropicChatCompletionStreaming( response.ExtraFields.RawResponse = eventData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } if isLastChunk { break @@ -919,7 +924,7 @@ func HandleAnthropicChatCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) }() return responseChan, nil @@ -991,7 +996,7 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc } // ResponsesStream performs a streaming responses request to the Anthropic API. 
-func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } @@ -1018,7 +1023,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, return HandleAnthropicResponsesStream( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesStreamRequest), jsonBody, headers, @@ -1030,6 +1035,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -1049,6 +1055,7 @@ func HandleAnthropicResponsesStream( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -1123,12 +1130,12 @@ func HandleAnthropicResponsesStream( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, 
logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1140,7 +1147,7 @@ func HandleAnthropicResponsesStream( fmt.Errorf("provider returned an empty response"), ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -1189,7 +1196,7 @@ func HandleAnthropicResponsesStream( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -1264,7 +1271,7 @@ func HandleAnthropicResponsesStream( return } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger, postHookSpanFinalizer) break } // Passthrough: when conversion returns no responses but we need to forward raw events, @@ -1284,7 +1291,7 @@ func HandleAnthropicResponsesStream( chunkIndex++ providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, passthroughResp, nil, nil, nil), - responseChan) + responseChan, 
postHookSpanFinalizer) continue } // Handle each response in the slice @@ -1324,10 +1331,10 @@ func HandleAnthropicResponsesStream( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) } } @@ -1857,7 +1864,7 @@ func (provider *AnthropicProvider) Speech(ctx *schemas.BifrostContext, key schem } // SpeechStream is not supported by the Anthropic provider. -func (provider *AnthropicProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -1867,7 +1874,7 @@ func (provider *AnthropicProvider) Transcription(ctx *schemas.BifrostContext, ke } // TranscriptionStream is not supported by the Anthropic provider. 
-func (provider *AnthropicProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1877,7 +1884,7 @@ func (provider *AnthropicProvider) ImageGeneration(ctx *schemas.BifrostContext, } // ImageGenerationStream is not supported by the Anthropic provider. -func (provider *AnthropicProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -1887,7 +1894,7 @@ func (provider *AnthropicProvider) ImageEdit(ctx *schemas.BifrostContext, key sc } // ImageEditStream is not supported by the Anthropic provider. 
-func (provider *AnthropicProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AnthropicProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -2585,6 +2592,7 @@ func (provider *AnthropicProvider) Passthrough( func (provider *AnthropicProvider) PassthroughStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -2622,7 +2630,7 @@ func (provider *AnthropicProvider) PassthroughStream( fasthttpReq.SetBody(req.Body) - activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp) + activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp) if err := activeClient.Do(fasthttpReq, resp); err != nil { providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -2669,12 +2677,12 @@ func (provider *AnthropicProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == 
context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } close(ch) }() @@ -2712,8 +2720,8 @@ func (provider *AnthropicProvider) PassthroughStream( }, } postHookRunner(ctx, finalResp, nil) - if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { - finalizer(ctx) + if postHookSpanFinalizer != nil { + postHookSpanFinalizer(ctx) } return } @@ -2723,7 +2731,7 @@ func (provider *AnthropicProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer) return } } diff --git a/core/providers/anthropic/chat_test.go b/core/providers/anthropic/chat_test.go index b73002009b..fd0a49cb6b 100644 --- a/core/providers/anthropic/chat_test.go +++ b/core/providers/anthropic/chat_test.go @@ -378,72 +378,42 @@ func TestToBifrostChatResponse_MultipleTextBlocksWithThinking(t *testing.T) { t.Fatal("expected non-nil result") } - // Content should be a combined string, not blocks + // With multiple text blocks, ToBifrostChatResponse preserves them as ContentBlocks + // (only a single text block collapses to ContentStr — see chat.go:812-815). + // Thinking flows through ReasoningDetails below, not ContentStr. 
choice := result.Choices[0] msg := choice.ChatNonStreamResponseChoice.Message - if msg.Content.ContentBlocks != nil { - t.Error("expected ContentBlocks to be nil (combined into string)") + if msg.Content.ContentStr != nil { + t.Errorf("expected ContentStr to be nil with multiple text blocks, got %q", *msg.Content.ContentStr) } - if msg.Content.ContentStr == nil { - t.Fatal("expected ContentStr to be non-nil") + if len(msg.Content.ContentBlocks) != 2 { + t.Fatalf("expected 2 content blocks (one per text block), got %d", len(msg.Content.ContentBlocks)) } - - // Combined string: thinking first, then text blocks - expected := thinkingText + "\n\n" + textBlock1 + "\n\n" + textBlock2 - if *msg.Content.ContentStr != expected { - t.Errorf("expected combined content:\n%s\ngot:\n%s", expected, *msg.Content.ContentStr) + if msg.Content.ContentBlocks[0].Text == nil || *msg.Content.ContentBlocks[0].Text != textBlock1 { + t.Errorf("block 0 text mismatch: got %v, want %q", msg.Content.ContentBlocks[0].Text, textBlock1) + } + if msg.Content.ContentBlocks[1].Text == nil || *msg.Content.ContentBlocks[1].Text != textBlock2 { + t.Errorf("block 1 text mismatch: got %v, want %q", msg.Content.ContentBlocks[1].Text, textBlock2) } - // Reasoning field should still have thinking text + // Thinking is surfaced via ReasoningDetails with the signature preserved + // (see chat.go:798-807). 
if msg.ChatAssistantMessage == nil { t.Fatal("expected ChatAssistantMessage to be non-nil") } - if msg.ChatAssistantMessage.Reasoning == nil { - t.Fatal("expected Reasoning to be non-nil") - } - - // ReasoningDetails should have: signature-only thinking entry + content blocks boundary rd := msg.ChatAssistantMessage.ReasoningDetails - if len(rd) < 2 { - t.Fatalf("expected at least 2 reasoning details entries, got %d", len(rd)) + if len(rd) != 1 { + t.Fatalf("expected 1 reasoning details entry (the thinking block), got %d", len(rd)) } - - // First entry: thinking with signature, no text (text was cleared) if rd[0].Type != schemas.BifrostReasoningDetailsTypeText { - t.Errorf("expected first reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type) + t.Errorf("expected reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type) } if rd[0].Signature == nil || *rd[0].Signature != signature { - t.Error("expected signature to be preserved") - } - if rd[0].Text != nil { - t.Error("expected thinking text to be nil (cleared to avoid duplication)") - } - - // Last entry: content blocks boundary - lastRD := rd[len(rd)-1] - if lastRD.Type != schemas.BifrostReasoningDetailsTypeContentBlocks { - t.Errorf("expected last reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeContentBlocks, lastRD.Type) - } - if lastRD.Text == nil { - t.Fatal("expected content blocks metadata to be non-nil") - } - - // var meta []contentBlockMeta - // if err := json.Unmarshal([]byte(*lastRD.Text), &meta); err != nil { - // t.Fatalf("failed to unmarshal block metadata: %v", err) - // } - // if len(meta) != 3 { - // t.Fatalf("expected 3 block metadata entries, got %d", len(meta)) - // } - // if meta[0].T != "thinking" || meta[0].L != len(thinkingText) { - // t.Errorf("block 0: expected thinking/%d, got %s/%d", len(thinkingText), meta[0].T, meta[0].L) - // } - // if meta[1].T != "text" || meta[1].L != len(textBlock1) { - // 
t.Errorf("block 1: expected text/%d, got %s/%d", len(textBlock1), meta[1].T, meta[1].L) - // } - // if meta[2].T != "text" || meta[2].L != len(textBlock2) { - // t.Errorf("block 2: expected text/%d, got %s/%d", len(textBlock2), meta[2].T, meta[2].L) - // } + t.Error("expected thinking signature to be preserved on reasoning detail") + } + if rd[0].Text == nil || *rd[0].Text != thinkingText { + t.Errorf("expected reasoning text to match thinking text") + } } func TestToBifrostChatResponse_SingleTextBlockNoThinking(t *testing.T) { diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 323d13584e..beac667dd0 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -35,9 +35,10 @@ const DefaultAzureScope = "https://cognitiveservices.azure.com/.default" // AzureProvider implements the Provider interface for Azure's API. type AzureProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests - networkConfig schemas.NetworkConfig // Network configuration including extra headers + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) + networkConfig schemas.NetworkConfig // Network configuration including extra headers credentials sync.Map // map of tenant ID:client ID to azcore.TokenCredential sendBackRawRequest bool // Whether to include raw request in BifrostResponse @@ -184,9 +185,11 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := 
providerUtils.BuildStreamingClient(client) return &AzureProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -467,7 +470,7 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s // TextCompletionStream performs a streaming text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -483,7 +486,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, url, request, authHeader, @@ -496,6 +499,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -596,7 +600,7 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. 
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var url string if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) @@ -628,7 +632,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, // Use shared streaming logic from Anthropic return anthropic.HandleAnthropicChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, url, jsonData, authHeader, @@ -640,6 +644,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -655,7 +660,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, // Use shared streaming logic from OpenAI return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, url, request, authHeader, @@ -670,6 +675,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } } @@ -763,7 +769,7 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema } // ResponsesStream performs a streaming responses request to Azure's API. 
-func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var url string if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) @@ -781,7 +787,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post // Use shared streaming logic from Anthropic return anthropic.HandleAnthropicResponsesStream( ctx, - provider.client, + provider.streamingClient, url, jsonData, authHeader, @@ -793,6 +799,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -804,7 +811,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post // Use shared streaming logic from OpenAI return openai.HandleOpenAIResponsesStreaming( ctx, - provider.client, + provider.streamingClient, url, request, authHeader, @@ -818,6 +825,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post nil, nil, provider.logger, + postHookSpanFinalizer, ) } } @@ -933,7 +941,7 @@ func (provider *AzureProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, // SpeechStream handles streaming for speech synthesis with Azure. // Azure sends raw binary audio bytes in SSE format, unlike OpenAI which sends JSON. 
-func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Get Azure authentication headers authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1031,12 +1039,12 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1138,7 +1146,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo if errParseErr := sonic.Unmarshal(audioData, &bifrostErr); errParseErr == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } } 
@@ -1167,7 +1175,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo response.ExtraFields.RawResponse = audioData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan, postHookSpanFinalizer) } // Check if we received [DONE] marker - break outer loop to send final response @@ -1188,7 +1196,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // a fake "done" response with truncated audio. ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -1212,7 +1220,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo finalResponse.BackfillParams(request) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &finalResponse, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &finalResponse, nil, nil), responseChan, postHookSpanFinalizer) } else if chunkIndex >= 0 && !doneReceived { provider.logger.Warn("Stream ended without receiving [DONE] marker after %d chunks", chunkIndex+1) } @@ -1253,7 +1261,7 @@ func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key sc } // TranscriptionStream is not supported by the Azure provider. 
-func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1297,6 +1305,7 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key func (provider *AzureProvider) ImageGenerationStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -1320,7 +1329,7 @@ func (provider *AzureProvider) ImageGenerationStream( // Azure is OpenAI-compatible return openai.HandleOpenAIImageGenerationStreaming( ctx, - provider.client, + provider.streamingClient, url, request, authHeader, @@ -1333,6 +1342,7 @@ func (provider *AzureProvider) ImageGenerationStream( nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -1370,7 +1380,7 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema } // ImageEditStream performs a streaming image edit request to Azure's API. 
-func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault) @@ -1391,7 +1401,7 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post // Azure is OpenAI-compatible return openai.HandleOpenAIImageEditStreamRequest( ctx, - provider.client, + provider.streamingClient, url, request, authHeader, @@ -1404,6 +1414,7 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -2766,6 +2777,7 @@ func (provider *AzureProvider) Passthrough( func (provider *AzureProvider) PassthroughStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -2797,7 +2809,7 @@ func (provider *AzureProvider) PassthroughStream( fasthttpReq.SetBody(req.Body) - activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp) + activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp) providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds) startTime := time.Now() @@ -2839,12 +2851,12 @@ func (provider *AzureProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, 
schemas.DefaultStreamBufferSize) go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } close(ch) }() @@ -2882,8 +2894,8 @@ func (provider *AzureProvider) PassthroughStream( }, } postHookRunner(ctx, finalResp, nil) - if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { - finalizer(ctx) + if postHookSpanFinalizer != nil { + postHookSpanFinalizer(ctx) } return } @@ -2893,7 +2905,7 @@ func (provider *AzureProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer) return } } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 6b7cf700f0..5d93f54b93 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -33,7 +33,8 @@ import ( // BedrockProvider implements the Provider interface for AWS Bedrock. 
type BedrockProvider struct { logger schemas.Logger // Logger for provider operations - client *http.Client // HTTP client for API requests + client *http.Client // HTTP client for unary API requests (Client.Timeout bounds overall response) + streamingClient *http.Client // HTTP client for streaming API requests (no Timeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers customProviderConfig *schemas.CustomProviderConfig // Custom provider config sendBackRawRequest bool // Whether to include raw request in BifrostResponse @@ -114,6 +115,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) ( } client := &http.Client{Transport: transport, Timeout: requestTimeout} + streamingClient := providerUtils.BuildStreamingHTTPClient(client) // Pre-warm response pools for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { @@ -123,6 +125,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) ( return &BedrockProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, customProviderConfig: config.CustomProviderConfig, sendBackRawRequest: config.SendBackRawRequest, @@ -456,7 +459,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex } // Make the request - resp, respErr := provider.client.Do(req) + resp, respErr := provider.streamingClient.Do(req) if respErr != nil { if errors.Is(respErr, context.Canceled) { return nil, &schemas.BifrostError{ @@ -881,7 +884,7 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil { return nil, err } @@ -912,12 +915,12 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -963,9 +966,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex Message: schemas.ErrProviderNetworkError, Error: err, }, - }, responseChan, provider.logger) + }, responseChan, provider.logger, postHookSpanFinalizer) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) } return } 
@@ -996,10 +999,10 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex Error: &schemas.ErrorField{ Message: fmt.Sprintf("%s stream %s: %s", providerName, excType, errMsg), }, - }, responseChan, provider.logger) + }, responseChan, provider.logger, postHookSpanFinalizer) } else { err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) } return } @@ -1011,7 +1014,7 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex } if err := sonic.Unmarshal(message.Payload, &chunkPayload); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -1024,7 +1027,7 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex }, } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(textResponse, nil, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(textResponse, nil, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } } }() @@ -1111,7 +1114,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key // ChatCompletionStream performs a streaming chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the streaming response. // Returns a channel for streaming BifrostStreamChunk objects or an error if the request fails. 
-func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -1138,12 +1141,12 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds) // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1209,9 +1212,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex Message: schemas.ErrProviderNetworkError, Error: err, }, - }, responseChan, provider.logger) + }, responseChan, provider.logger, postHookSpanFinalizer) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, 
postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) } return } @@ -1239,9 +1242,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex Error: &schemas.ErrorField{ Message: err.Error(), }, - }, responseChan, provider.logger) + }, responseChan, provider.logger, postHookSpanFinalizer) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) } return } @@ -1251,7 +1254,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -1341,7 +1344,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response.ExtraFields.RawResponse = string(message.Payload) } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) continue } } @@ -1349,7 +1352,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response, bifrostErr, _ := streamEvent.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) 
+ providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } if response != nil { @@ -1366,7 +1369,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response.ExtraFields.RawResponse = string(message.Payload) } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } } } @@ -1383,7 +1386,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) }() return responseChan, nil @@ -1460,7 +1463,7 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche // ResponsesStream performs a streaming chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the streaming response. // Returns a channel for streaming BifrostResponse objects or an error if the request fails. 
-func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } @@ -1489,12 +1492,12 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1567,7 +1570,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil, nil), 
responseChan, postHookSpanFinalizer) } break } @@ -1582,9 +1585,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Message: schemas.ErrProviderNetworkError, Error: err, }, - }, responseChan, provider.logger) + }, responseChan, provider.logger, postHookSpanFinalizer) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) } return } @@ -1612,9 +1615,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Error: &schemas.ErrorField{ Message: err.Error(), }, - }, responseChan, provider.logger) + }, responseChan, provider.logger, postHookSpanFinalizer) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) } return } @@ -1624,7 +1627,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -1691,7 +1694,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po response.ExtraFields.RawResponse = string(message.Payload) } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, 
nil, nil, nil), responseChan, postHookSpanFinalizer) continue } } @@ -1699,7 +1702,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } for _, response := range responses { @@ -1715,7 +1718,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po response.ExtraFields.RawResponse = string(message.Payload) } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) } } } @@ -1894,7 +1897,7 @@ func (provider *BedrockProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke } // SpeechStream is not supported by the Bedrock provider. 
-func (provider *BedrockProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, schemas.Bedrock) } @@ -1904,7 +1907,7 @@ func (provider *BedrockProvider) Transcription(ctx *schemas.BifrostContext, key } // TranscriptionStream is not supported by the Bedrock provider. -func (provider *BedrockProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, schemas.Bedrock) } @@ -1977,7 +1980,7 @@ func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, ke } // ImageGenerationStream is not supported by the Bedrock provider. 
-func (provider *BedrockProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, schemas.Bedrock) } @@ -2052,7 +2055,7 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche } // ImageEditStream is not supported by the Bedrock provider. -func (provider *BedrockProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -3605,6 +3608,6 @@ func (provider *BedrockProvider) Passthrough(_ *schemas.BifrostContext, _ schema return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *BedrockProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *BedrockProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ 
func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/bedrock/transport_test.go b/core/providers/bedrock/transport_test.go index 1e2a447e9d..d8ba5daef6 100644 --- a/core/providers/bedrock/transport_test.go +++ b/core/providers/bedrock/transport_test.go @@ -71,13 +71,18 @@ func newTestProviderWithServer(t *testing.T, ts *httptest.Server) *BedrockProvid targetURL, err := url.Parse(ts.URL) require.NoError(t, err) + redirect := &redirectTransport{ + target: targetURL, + transport: ts.Client().Transport, + } provider.client = &http.Client{ - Transport: &redirectTransport{ - target: targetURL, - transport: ts.Client().Transport, - }, - Timeout: 5 * time.Second, + Transport: redirect, + Timeout: 5 * time.Second, } + // Streaming paths use streamingClient (no Timeout); redirect it to the + // test server too, otherwise Bedrock streaming tests would hit the real + // AWS endpoint. + provider.streamingClient = &http.Client{Transport: redirect} return provider } @@ -177,7 +182,7 @@ func TestChatCompletionStream_StaleConnection_ChunkIsRetryable(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, key, testChatRequest()) + streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest()) if bifrostErr != nil { // Error surfaced synchronously (e.g. connection refused before HTTP 200). 
@@ -242,7 +247,7 @@ func TestChatCompletionStream_NetOpError_ChunkIsRetryable(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, key, testChatRequest()) + streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest()) if bifrostErr != nil { assert.False(t, bifrostErr.IsBifrostError, "pre-stream network error must be IsBifrostError:false") @@ -326,7 +331,7 @@ func TestChatCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) ctx := testBedrockCtx() key := testBedrockKey() - streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, key, testChatRequest()) + streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest()) require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk") require.NotNil(t, streamChan) @@ -373,7 +378,7 @@ func TestChatCompletionStream_NonRetryableException_IsTerminal(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, key, testChatRequest()) + streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest()) require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk") require.NotNil(t, streamChan) @@ -476,7 +481,7 @@ func TestTextCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) defer ts.Close() provider := newTestProviderWithServer(t, ts) - streamChan, bifrostErr := provider.TextCompletionStream(testBedrockCtx(), noopPostHookRunner, testBedrockKey(), testTextCompletionRequest()) + streamChan, bifrostErr := provider.TextCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testTextCompletionRequest()) assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus) 
}) } @@ -512,7 +517,7 @@ func TestResponsesStream_RetryableException_ChunkIsRetryable(t *testing.T) { defer ts.Close() provider := newTestProviderWithServer(t, ts) - streamChan, bifrostErr := provider.ResponsesStream(testBedrockCtx(), noopPostHookRunner, testBedrockKey(), testResponsesRequest()) + streamChan, bifrostErr := provider.ResponsesStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testResponsesRequest()) assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus) }) } diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index c32dcd7374..45292d6d24 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -2,6 +2,7 @@ package cerebras import ( + "context" "strings" "time" @@ -14,7 +15,8 @@ import ( // CerebrasProvider implements the Provider interface for Cerebras's API. type CerebrasProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -41,6 +43,7 @@ func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger) client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = 
"https://api.cerebras.ai" @@ -50,6 +53,7 @@ func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger) return &CerebrasProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -99,7 +103,7 @@ func (provider *CerebrasProvider) TextCompletion(ctx *schemas.BifrostContext, ke // TextCompletionStream performs a streaming text completion request to Cerebras's API. // It formats the request, sends it to Cerebras, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -107,7 +111,7 @@ func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostConte // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/completions", request, authHeader, @@ -120,6 +124,7 @@ func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostConte nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -145,7 +150,7 @@ func (provider *CerebrasProvider) ChatCompletion(ctx *schemas.BifrostContext, ke // It supports real-time 
streaming of responses using Server-Sent Events (SSE). // Uses Cerebras's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. -func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -153,7 +158,7 @@ func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostConte // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/chat/completions", request, authHeader, @@ -168,6 +173,7 @@ func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostConte nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -183,11 +189,12 @@ func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key sch } // ResponsesStream performs a streaming responses request to the Cerebras API. 
-func (provider *CerebrasProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -214,7 +221,7 @@ func (provider *CerebrasProvider) OCR(ctx *schemas.BifrostContext, key schemas.K } // SpeechStream is not supported by the Cerebras provider. -func (provider *CerebrasProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -224,7 +231,7 @@ func (provider *CerebrasProvider) Transcription(ctx *schemas.BifrostContext, key } // TranscriptionStream is not supported by the Cerebras provider. 
-func (provider *CerebrasProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -234,7 +241,7 @@ func (provider *CerebrasProvider) ImageGeneration(ctx *schemas.BifrostContext, k } // ImageGenerationStream is not supported by the Cerebras provider. -func (provider *CerebrasProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -244,7 +251,7 @@ func (provider *CerebrasProvider) ImageEdit(ctx *schemas.BifrostContext, key sch } // ImageEditStream is not supported by the Cerebras provider. 
-func (provider *CerebrasProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -393,6 +400,6 @@ func (provider *CerebrasProvider) Passthrough(_ *schemas.BifrostContext, _ schem return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *CerebrasProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CerebrasProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index 1e5d50e087..1c2e371ec8 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -87,7 +87,8 @@ func releaseCohereResponse(resp *CohereChatResponse) { // CohereProvider implements the Provider interface for Cohere. 
type CohereProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -122,6 +123,8 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* cohereRerankResponsePool.Put(&CohereRerankResponse{}) } + streamingClient := providerUtils.BuildStreamingClient(client) + // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.cohere.ai" @@ -131,6 +134,7 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* return &CohereProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, customProviderConfig: config.CustomProviderConfig, sendBackRawRequest: config.SendBackRawRequest, @@ -328,7 +332,7 @@ func (provider *CohereProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Cohere's API. // It formats the request, sends it to Cohere, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *CohereProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -402,7 +406,7 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key // ChatCompletionStream performs a streaming chat completion request to the Cohere API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if chat completion stream is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err @@ -451,7 +455,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } // Make the request - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if usedLargePayloadBody { providerUtils.DrainLargePayloadRemainder(ctx) } @@ -496,12 +500,12 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -539,7 +543,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) 
provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -562,7 +566,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) break } if response != nil { @@ -586,10 +590,10 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) break } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } } }() @@ -665,7 +669,7 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem } // ResponsesStream performs a streaming responses request to the Cohere API. 
-func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if responses stream is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err @@ -715,7 +719,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos } // Make the request - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if usedLargePayloadBody { providerUtils.DrainLargePayloadRemainder(ctx) } @@ -760,12 +764,12 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -808,7 +812,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) 
provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -829,7 +833,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos responses, bifrostErr, isLastChunk := event.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) break } // Handle each response in the slice @@ -856,10 +860,10 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) } } } @@ -1011,7 +1015,7 @@ func (provider *CohereProvider) Speech(ctx *schemas.BifrostContext, key schemas. } // SpeechStream is not supported by the Cohere provider. 
-func (provider *CohereProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -1021,7 +1025,7 @@ func (provider *CohereProvider) Transcription(ctx *schemas.BifrostContext, key s } // TranscriptionStream is not supported by the Cohere provider. -func (provider *CohereProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1031,7 +1035,7 @@ func (provider *CohereProvider) ImageGeneration(ctx *schemas.BifrostContext, key } // ImageGenerationStream is not supported by the Cohere provider. 
-func (provider *CohereProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -1041,7 +1045,7 @@ func (provider *CohereProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } // ImageEditStream is not supported by the Cohere provider. -func (provider *CohereProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -1257,6 +1261,6 @@ func (provider *CohereProvider) Passthrough(_ *schemas.BifrostContext, _ schemas return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *CohereProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *CohereProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ 
func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index bcd3e5cfc7..0dd630de71 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -21,7 +21,8 @@ import ( type ElevenlabsProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -49,6 +50,7 @@ func NewElevenlabsProvider(config *schemas.ProviderConfig, logger schemas.Logger client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.elevenlabs.io" @@ -58,6 +60,7 @@ func NewElevenlabsProvider(config *schemas.ProviderConfig, logger schemas.Logger return &ElevenlabsProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, customProviderConfig: config.CustomProviderConfig, sendBackRawRequest: config.SendBackRawRequest, @@ -147,7 +150,7 @@ func (provider 
*ElevenlabsProvider) TextCompletion(ctx *schemas.BifrostContext, } // TextCompletionStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -157,7 +160,7 @@ func (provider *ElevenlabsProvider) ChatCompletion(ctx *schemas.BifrostContext, } // ChatCompletionStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionStreamRequest, provider.GetProviderKey()) } @@ -167,7 +170,7 @@ func (provider *ElevenlabsProvider) Responses(ctx *schemas.BifrostContext, key s } // ResponsesStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, 
*schemas.BifrostError) { +func (provider *ElevenlabsProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesStreamRequest, provider.GetProviderKey()) } @@ -304,7 +307,7 @@ func (provider *ElevenlabsProvider) OCR(ctx *schemas.BifrostContext, key schemas } // SpeechStream performs a text to speech stream request -func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err } @@ -347,7 +350,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po // Make request startTime := time.Now() - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -383,9 +386,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == 
context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -402,7 +405,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po // which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer). stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) defer stopCancellation() - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) // read binary audio chunks from the stream // 4KB buffer for reading chunks @@ -427,7 +430,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", err) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -451,7 +454,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po response.ExtraFields.RawResponse = audioChunk } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan, postHookSpanFinalizer) } } @@ -470,7 +473,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody) } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - 
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil, nil), responseChan, postHookSpanFinalizer) }() return responseChan, nil @@ -711,7 +714,7 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr } // TranscriptionStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -721,7 +724,7 @@ func (provider *ElevenlabsProvider) ImageGeneration(ctx *schemas.BifrostContext, } // ImageGenerationStream is not supported by the Elevenlabs provider. 
-func (provider *ElevenlabsProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -731,7 +734,7 @@ func (provider *ElevenlabsProvider) ImageEdit(ctx *schemas.BifrostContext, key s } // ImageEditStream is not supported by the Elevenlabs provider. -func (provider *ElevenlabsProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -925,6 +928,6 @@ func (provider *ElevenlabsProvider) Passthrough(_ *schemas.BifrostContext, _ sch return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *ElevenlabsProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) PassthroughStream(_ *schemas.BifrostContext, _ 
schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/fireworks/fireworks.go b/core/providers/fireworks/fireworks.go index 9897b71efe..827d1777df 100644 --- a/core/providers/fireworks/fireworks.go +++ b/core/providers/fireworks/fireworks.go @@ -2,6 +2,7 @@ package fireworks import ( + "context" "strings" "time" @@ -14,7 +15,8 @@ import ( // FireworksProvider implements the Provider interface for Fireworks AI's API. type FireworksProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -41,6 +43,7 @@ func NewFireworksProvider(config *schemas.ProviderConfig, logger schemas.Logger) client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.fireworks.ai/inference" @@ -50,6 +53,7 @@ func NewFireworksProvider(config *schemas.ProviderConfig, logger schemas.Logger) return &FireworksProvider{ logger: logger, client: client, + streamingClient: streamingClient, 
networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -96,14 +100,14 @@ func (provider *FireworksProvider) TextCompletion(ctx *schemas.BifrostContext, k } // TextCompletionStream performs a streaming text completion request to the Fireworks AI API. -func (provider *FireworksProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if v := key.Value.GetValue(); v != "" { authHeader = map[string]string{"Authorization": "Bearer " + v} } return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, authHeader, @@ -116,6 +120,7 @@ func (provider *FireworksProvider) TextCompletionStream(ctx *schemas.BifrostCont nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -141,7 +146,7 @@ func (provider *FireworksProvider) ChatCompletion(ctx *schemas.BifrostContext, k // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Fireworks AI's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if v := key.Value.GetValue(); v != "" { authHeader = map[string]string{"Authorization": "Bearer " + v} @@ -149,7 +154,7 @@ func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostCont // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, authHeader, @@ -164,6 +169,7 @@ func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostCont nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -186,14 +192,14 @@ func (provider *FireworksProvider) Responses(ctx *schemas.BifrostContext, key sc } // ResponsesStream performs a streaming responses request to the Fireworks AI API. 
-func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if v := key.Value.GetValue(); v != "" { authHeader = map[string]string{"Authorization": "Bearer " + v} } return openai.HandleOpenAIResponsesStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"), request, authHeader, @@ -207,6 +213,7 @@ func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -243,7 +250,7 @@ func (provider *FireworksProvider) OCR(ctx *schemas.BifrostContext, key schemas. } // SpeechStream is not supported by the Fireworks AI provider. -func (provider *FireworksProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -253,7 +260,7 @@ func (provider *FireworksProvider) Transcription(ctx *schemas.BifrostContext, ke } // TranscriptionStream is not supported by the Fireworks AI provider. 
-func (provider *FireworksProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -263,7 +270,7 @@ func (provider *FireworksProvider) ImageGeneration(ctx *schemas.BifrostContext, } // ImageGenerationStream is not supported by the Fireworks AI provider. -func (provider *FireworksProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -273,7 +280,7 @@ func (provider *FireworksProvider) ImageEdit(ctx *schemas.BifrostContext, key sc } // ImageEditStream is not supported by the Fireworks AI provider. 
-func (provider *FireworksProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -423,6 +430,6 @@ func (provider *FireworksProvider) Passthrough(_ *schemas.BifrostContext, _ sche } // PassthroughStream is not supported by the Fireworks AI provider. -func (provider *FireworksProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *FireworksProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/fireworks/fireworks_test.go b/core/providers/fireworks/fireworks_test.go index 088ffb8693..445c6c83d0 100644 --- a/core/providers/fireworks/fireworks_test.go +++ b/core/providers/fireworks/fireworks_test.go @@ -399,7 +399,7 @@ func TestFireworksResponsesStreamUsesNativeResponsesEndpoint(t *testing.T) { return result, err } - stream, err := provider.ResponsesStream(ctx, postHookRunner, key, &schemas.BifrostResponsesRequest{ + stream, err := provider.ResponsesStream(ctx, postHookRunner, nil, key, &schemas.BifrostResponsesRequest{ Provider: schemas.Fireworks, Model: 
"accounts/fireworks/models/deepseek-v3p2", Input: []schemas.ResponsesMessage{ diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index cad4216534..521b05223c 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -26,19 +26,14 @@ const ( type GeminiProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse customProviderConfig *schemas.CustomProviderConfig // Custom provider config } -func buildStreamingResponseClient(base *fasthttp.Client) *fasthttp.Client { - client := providerUtils.CloneFastHTTPClientConfig(base) - client.StreamResponseBody = true - return client -} - func setGeminiRequestBody(req *fasthttp.Request, bodyReader io.Reader, bodySize int, jsonData []byte) { // Large payload mode streams request bytes directly from the ingress reader. // Normal mode sends marshaled JSON as before. 
@@ -72,6 +67,7 @@ func NewGeminiProvider(config *schemas.ProviderConfig, logger schemas.Logger) *G client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { @@ -82,6 +78,7 @@ func NewGeminiProvider(config *schemas.ProviderConfig, logger schemas.Logger) *G return &GeminiProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, customProviderConfig: config.CustomProviderConfig, sendBackRawRequest: config.SendBackRawRequest, @@ -266,7 +263,7 @@ func (provider *GeminiProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Gemini's API. // It formats the request, sends it to Gemini, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *GeminiProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -327,7 +324,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key // ChatCompletionStream performs a streaming chat completion request to the Gemini API. 
// It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. -func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err @@ -364,7 +361,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Use shared Gemini streaming logic return HandleGeminiChatCompletionStream( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/models/"+request.Model+":streamGenerateContent?alt=sse"), jsonData, headers, @@ -376,6 +373,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -394,6 +392,7 @@ func HandleGeminiChatCompletionStream( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -415,9 +414,8 @@ func HandleGeminiChatCompletionStream( req.SetBody(jsonBody) } - // Make the request - streamingClient := 
buildStreamingResponseClient(client) - doErr := streamingClient.Do(req, resp) + // Make the request — caller is responsible for passing a streaming-configured client. + doErr := client.Do(req, resp) if doErr != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(doErr, context.Canceled) { @@ -458,12 +456,12 @@ func HandleGeminiChatCompletionStream( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -475,7 +473,7 @@ func HandleGeminiChatCompletionStream( fmt.Errorf("provider returned an empty response"), ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -534,7 +532,7 @@ func HandleGeminiChatCompletionStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) 
return } // Process chunk using shared function @@ -551,7 +549,7 @@ func HandleGeminiChatCompletionStream( }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } logger.Warn("Failed to process chunk: %v", err) @@ -570,7 +568,7 @@ func HandleGeminiChatCompletionStream( response, bifrostErr, isLastChunk := geminiResponse.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -605,12 +603,12 @@ func HandleGeminiChatCompletionStream( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) break } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, 
response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } } }() @@ -823,7 +821,7 @@ func extractUsageFromResponsePrefetch(data []byte) *schemas.ResponsesResponseUsa } // ResponsesStream performs a streaming responses request to the Gemini API. -func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if responses stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err @@ -859,7 +857,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return HandleGeminiResponsesStream( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/models/"+request.Model+":streamGenerateContent?alt=sse"), jsonData, headers, @@ -871,6 +869,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, pos postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -889,6 +888,7 @@ func HandleGeminiResponsesStream( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := 
fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -910,9 +910,8 @@ func HandleGeminiResponsesStream( req.SetBody(jsonBody) } - // Make the request - streamingClient := buildStreamingResponseClient(client) - doErr := streamingClient.Do(req, resp) + // Make the request — caller is responsible for passing a streaming-configured client. + doErr := client.Do(req, resp) if doErr != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(doErr, context.Canceled) { @@ -952,12 +951,12 @@ func HandleGeminiResponsesStream( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -976,6 +975,7 @@ func HandleGeminiResponsesStream( providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, + postHookSpanFinalizer, ) return } @@ -1036,7 +1036,7 @@ func HandleGeminiResponsesStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) return } @@ -1054,7 +1054,7 @@ func HandleGeminiResponsesStream( }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, 
postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } logger.Warn("Failed to process chunk: %v", err) @@ -1070,7 +1070,7 @@ func HandleGeminiResponsesStream( responses, bifrostErr := geminiResponse.ToBifrostResponsesStream(sequenceNumber, streamState) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -1109,7 +1109,7 @@ func HandleGeminiResponsesStream( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) return } @@ -1119,7 +1119,7 @@ func HandleGeminiResponsesStream( } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, 
nil, nil), responseChan, postHookSpanFinalizer) } } } @@ -1155,7 +1155,7 @@ func HandleGeminiResponsesStream( ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil, nil), responseChan, postHookSpanFinalizer) } }() @@ -1357,7 +1357,7 @@ func (provider *GeminiProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key } // SpeechStream performs a streaming speech synthesis request to the Gemini API. -func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if speech stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err @@ -1400,8 +1400,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo } // Make the request - streamingClient := buildStreamingResponseClient(provider.client) - err := streamingClient.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -1443,12 +1442,12 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo // 
Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1488,7 +1487,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -1510,7 +1509,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } provider.logger.Warn("Failed to process chunk: %v", err) @@ -1569,7 +1568,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan) + 
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan, postHookSpanFinalizer) } } response := &schemas.BifrostSpeechStreamResponse{ @@ -1586,7 +1585,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan, postHookSpanFinalizer) }() return responseChan, nil @@ -1647,7 +1646,7 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s } // TranscriptionStream performs a streaming speech-to-text request to the Gemini API. 
-func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if transcription stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { return nil, err @@ -1690,8 +1689,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, } // Make the request - streamingClient := buildStreamingResponseClient(provider.client) - err := streamingClient.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -1733,12 +1731,12 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1778,7 +1776,7 @@ func (provider 
*GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -1799,7 +1797,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } provider.logger.Warn("Failed to process chunk: %v", err) @@ -1852,7 +1850,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer) } } response := &schemas.BifrostTranscriptionStreamResponse{ @@ -1875,7 +1873,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, 
response, nil), responseChan, postHookSpanFinalizer) }() @@ -2035,7 +2033,7 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost } // ImageGenerationStream is not supported by the Gemini provider. -func (provider *GeminiProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -2179,7 +2177,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } // ImageEditStream is not supported by the Gemini provider. 
-func (provider *GeminiProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GeminiProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -4117,6 +4115,7 @@ func (provider *GeminiProvider) Passthrough( func (provider *GeminiProvider) PassthroughStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -4154,7 +4153,7 @@ func (provider *GeminiProvider) PassthroughStream( fasthttpReq.SetBody(req.Body) - activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp) + activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp) if err := activeClient.Do(fasthttpReq, resp); err != nil { providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -4201,12 +4200,12 @@ func (provider *GeminiProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - 
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } close(ch) }() @@ -4244,8 +4243,8 @@ func (provider *GeminiProvider) PassthroughStream( }, } postHookRunner(ctx, finalResp, nil) - if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { - finalizer(ctx) + if postHookSpanFinalizer != nil { + postHookSpanFinalizer(ctx) } return } @@ -4255,7 +4254,7 @@ func (provider *GeminiProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer) return } } diff --git a/core/providers/gemini/list_models_single_payload_test.go b/core/providers/gemini/list_models_single_payload_test.go index 127bfabbd1..483b3a414d 100644 --- a/core/providers/gemini/list_models_single_payload_test.go +++ b/core/providers/gemini/list_models_single_payload_test.go @@ -48,7 +48,10 @@ func TestListModelsByKey_ParsesSingleModelPayload(t *testing.T) { ctx.SetValue(schemas.BifrostContextKeyURLPath, "/models/gemini-2.5-pro") key := schemas.Key{Value: *schemas.NewEnvVar("dummy-key")} - resp, err := provider.listModelsByKey(ctx, key, &schemas.BifrostListModelsRequest{Provider: schemas.Gemini}) + // Unfiltered=true bypasses the allowed/alias/blacklist filter pipeline so + // this test can focus on the single-model-payload parsing code path in + // listModelsByKey (gemini.go:215-220). 
+ resp, err := provider.listModelsByKey(ctx, key, &schemas.BifrostListModelsRequest{Provider: schemas.Gemini, Unfiltered: true}) require.Nil(t, err) require.NotNil(t, resp) require.Len(t, resp.Data, 1) diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index f152a10e94..9667b989ff 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -2,6 +2,7 @@ package groq import ( + "context" "strings" "time" @@ -14,7 +15,8 @@ import ( // GroqProvider implements the Provider interface for Groq's API. type GroqProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -46,6 +48,7 @@ func NewGroqProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*Gr client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.groq.com/openai" @@ -55,6 +58,7 @@ func NewGroqProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*Gr return &GroqProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -89,7 +93,7 @@ func 
(provider *GroqProvider) TextCompletion(ctx *schemas.BifrostContext, key sc // TextCompletionStream performs a streaming text completion request to Groq's API. // It formats the request, sends it to Groq, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *GroqProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError("text completion", "groq") } @@ -115,7 +119,7 @@ func (provider *GroqProvider) ChatCompletion(ctx *schemas.BifrostContext, key sc // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Groq's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if v := key.Value.GetValue(); v != "" { authHeader = map[string]string{"Authorization": "Bearer " + v} @@ -123,7 +127,7 @@ func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/chat/completions", request, authHeader, @@ -138,6 +142,7 @@ func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -154,11 +159,12 @@ func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas } // ResponsesStream performs a streaming responses request to the Groq API. 
-func (provider *GroqProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -199,7 +205,7 @@ func (provider *GroqProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, } // SpeechStream is not supported by the Groq provider. -func (provider *GroqProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -222,7 +228,7 @@ func (provider *GroqProvider) Transcription(ctx *schemas.BifrostContext, key sch } // TranscriptionStream is not supported by the Groq provider. 
-func (provider *GroqProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -232,7 +238,7 @@ func (provider *GroqProvider) ImageGeneration(ctx *schemas.BifrostContext, key s } // ImageGenerationStream is not supported by the Groq provider. -func (provider *GroqProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -242,7 +248,7 @@ func (provider *GroqProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas } // ImageEditStream is not supported by the Groq provider. 
-func (provider *GroqProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -391,6 +397,6 @@ func (provider *GroqProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.K return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *GroqProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *GroqProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index 110f5d6574..38ddd84e9f 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -21,7 +21,8 @@ import ( // HuggingFaceProvider implements the Provider interface for Hugging Face's inference APIs. 
type HuggingFaceProvider struct { logger schemas.Logger - client *fasthttp.Client + client *fasthttp.Client // unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig sendBackRawResponse bool sendBackRawRequest bool @@ -89,6 +90,7 @@ func NewHuggingFaceProvider(config *schemas.ProviderConfig, logger schemas.Logge client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = defaultInferenceBaseURL } @@ -97,6 +99,7 @@ func NewHuggingFaceProvider(config *schemas.ProviderConfig, logger schemas.Logge return &HuggingFaceProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawResponse: config.SendBackRawResponse, sendBackRawRequest: config.SendBackRawRequest, @@ -442,7 +445,7 @@ func (provider *HuggingFaceProvider) TextCompletion(ctx *schemas.BifrostContext, return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } -func (provider *HuggingFaceProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, 
providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -529,7 +532,7 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, return bifrostResponse, nil } -func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -569,7 +572,7 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostCo // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/chat/completions", schemas.ChatCompletionStreamRequest), request, authHeader, @@ -584,6 +587,7 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostCo nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -602,7 +606,7 @@ func (provider *HuggingFaceProvider) Responses(ctx *schemas.BifrostContext, key return response, nil } -func (provider *HuggingFaceProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key 
schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } @@ -611,6 +615,7 @@ func (provider *HuggingFaceProvider) ResponsesStream(ctx *schemas.BifrostContext return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -789,7 +794,7 @@ func (provider *HuggingFaceProvider) OCR(ctx *schemas.BifrostContext, key schema return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey()) } -func (provider *HuggingFaceProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -894,7 +899,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, } // TranscriptionStream is not supported by the Hugging Face provider. 
-func (provider *HuggingFaceProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -984,7 +989,7 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext // ImageGenerationStream handles streaming for fal-ai image generation. // Only fal-ai inference provider supports streaming for HuggingFace. -func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ImageGenerationStreamRequest); err != nil { return nil, err } @@ -1056,7 +1061,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } // Make the request - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -1098,7 +1103,7 @@ func (provider 
*HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer providerUtils.ReleaseStreamingResponse(resp) defer close(responseChan) @@ -1107,7 +1112,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC "Provider returned an empty response", fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -1148,7 +1153,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC fmt.Sprintf("Error reading fal-ai stream: %v", readErr), readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -1174,7 +1179,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC bifrostErr.Error.Message = errorResp.Error } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } } @@ -1225,7 +1230,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, 
nil, nil, chunk), - responseChan) + responseChan, postHookSpanFinalizer) } } @@ -1259,7 +1264,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, finalChunk), - responseChan) + responseChan, postHookSpanFinalizer) } }() @@ -1354,7 +1359,7 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key // ImageEditStream handles streaming for fal-ai image edit. // Only fal-ai inference provider supports streaming for HuggingFace. -func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ImageEditStreamRequest); err != nil { return nil, err } @@ -1435,7 +1440,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } // Make the request - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -1477,7 +1482,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer 
providerUtils.ReleaseStreamingResponse(resp) defer close(responseChan) @@ -1486,7 +1491,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext "Provider returned an empty response", fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -1527,7 +1532,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext fmt.Sprintf("Error reading fal-ai stream: %v", readErr), readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -1553,7 +1558,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext bifrostErr.Error.Message = errorResp.Error } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } } @@ -1604,7 +1609,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk), - responseChan) + responseChan, postHookSpanFinalizer) } } @@ -1638,7 +1643,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) 
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, finalChunk), - responseChan) + responseChan, postHookSpanFinalizer) } }() @@ -1791,6 +1796,6 @@ func (provider *HuggingFaceProvider) Passthrough(_ *schemas.BifrostContext, _ sc return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *HuggingFaceProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/mistral/custom_provider_test.go b/core/providers/mistral/custom_provider_test.go index 9015230544..cd7278f721 100644 --- a/core/providers/mistral/custom_provider_test.go +++ b/core/providers/mistral/custom_provider_test.go @@ -32,7 +32,8 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { assert.Equal(t, "invalid request", bifrostErr.Error.Message) assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type) assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code) - assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider) + // Note: ExtraFields.Provider is populated by bifrost.go's dispatcher via + // PopulateExtraFields, not by ParseMistralError called in isolation. 
} func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) { @@ -95,7 +96,7 @@ func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetad return response, err } - stream, bifrostErr := provider.ChatCompletionStream(ctx, postHookRunner, schemas.Key{}, request) + stream, bifrostErr := provider.ChatCompletionStream(ctx, postHookRunner, nil, schemas.Key{}, request) require.Nil(t, bifrostErr) var firstResponse *schemas.BifrostChatResponse @@ -110,7 +111,9 @@ func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetad } require.NotNil(t, firstResponse) - assert.Equal(t, customMistralProviderName, firstResponse.ExtraFields.Provider) + // Note: ExtraFields.Provider on stream chunks is populated by bifrost.go's + // dispatcher via PopulateExtraFields, not by provider streaming methods + // called in isolation. require.NotNil(t, capturedRequest) assert.Equal(t, float64(32), capturedRequest["max_tokens"]) @@ -153,6 +156,7 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) require.Nil(t, bifrostErr) require.NotNil(t, response) - assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider) - assert.Equal(t, "codestral-embed", response.ExtraFields.ResolvedModelUsed) + // Note: ExtraFields.Provider and ResolvedModelUsed are populated by + // bifrost.go's dispatcher via PopulateExtraFields, not by provider + // methods called in isolation. } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 1999cbb5fb..7400914542 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -19,7 +19,8 @@ import ( // MistralProvider implements the Provider interface for Mistral's API. 
type MistralProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers customProviderConfig *schemas.CustomProviderConfig sendBackRawRequest bool // Whether to include raw request in BifrostResponse @@ -52,6 +53,7 @@ func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) * client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.mistral.ai" @@ -61,6 +63,7 @@ func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) * return &MistralProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, customProviderConfig: config.CustomProviderConfig, sendBackRawRequest: config.SendBackRawRequest, @@ -152,7 +155,7 @@ func (provider *MistralProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Mistral's API. // It formats the request, sends it to Mistral, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *MistralProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -192,7 +195,7 @@ func (provider *MistralProvider) ChatCompletion(ctx *schemas.BifrostContext, key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Mistral's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -200,7 +203,7 @@ func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContex // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/chat/completions", provider.normalizeChatRequestForConversion(request), authHeader, @@ -215,6 +218,7 @@ func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContex nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -231,11 +235,12 @@ func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key sche } // ResponsesStream performs a streaming responses request to the Mistral API. 
-func (provider *MistralProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -382,7 +387,7 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke } // SpeechStream is not supported by the Mistral provider. -func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -488,7 +493,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // TranscriptionStream performs a streaming transcription request to Mistral's API. // It creates a multipart form with the audio file and streams transcription events. // Returns a channel of BifrostStreamChunk objects containing transcription deltas. 
-func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Convert Bifrost request to Mistral format @@ -535,7 +540,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext req.SetBody(body.Bytes()) // Make the request - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -579,9 +584,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -598,7 +603,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer). 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) defer stopCancellation() - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) sseReader := providerUtils.GetSSEEventReader(ctx, reader) chunkIndex := -1 @@ -621,7 +626,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer) } break } @@ -633,7 +638,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } chunkIndex++ - provider.processTranscriptionStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan) + provider.processTranscriptionStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan, postHookSpanFinalizer) // Break on terminal stream indicator (covers both done events and error events // that processTranscriptionStreamEvent signals via context). 
if ended, _ := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator).(bool); ended { @@ -657,6 +662,7 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( startTime time.Time, lastChunkTime *time.Time, responseChan chan *schemas.BifrostStreamChunk, + postHookSpanFinalizer func(context.Context), ) { // Skip empty data if strings.TrimSpace(jsonData) == "" { @@ -670,7 +676,7 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } } @@ -714,7 +720,7 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( response.Type = schemas.TranscriptionStreamResponseTypeDone } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer) } // BatchCreate is not supported by Mistral provider. @@ -783,7 +789,7 @@ func (provider *MistralProvider) ImageGeneration(ctx *schemas.BifrostContext, ke } // ImageGenerationStream is not supported by the Mistral provider. 
-func (provider *MistralProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -793,7 +799,7 @@ func (provider *MistralProvider) ImageEdit(ctx *schemas.BifrostContext, key sche } // ImageEditStream is not supported by the Mistral provider. -func (provider *MistralProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -882,6 +888,6 @@ func (provider *MistralProvider) Passthrough(_ *schemas.BifrostContext, _ schema return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *MistralProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *MistralProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, 
_ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/mistral/transcription_test.go b/core/providers/mistral/transcription_test.go index f5b8b7d1c3..3615cb2436 100644 --- a/core/providers/mistral/transcription_test.go +++ b/core/providers/mistral/transcription_test.go @@ -471,8 +471,9 @@ func TestTranscriptionWithMockServer(t *testing.T) { assert.Equal(t, 3.5, *resp.Duration) require.NotNil(t, resp.Language) assert.Equal(t, "en", *resp.Language) - assert.Equal(t, schemas.TranscriptionRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) + // Provider and RequestType on ExtraFields are populated by + // bifrost.go's dispatcher via PopulateExtraFields, not by + // provider methods called in isolation. }, }, { @@ -770,7 +771,7 @@ func TestTranscriptionStreamWithMockServer(t *testing.T) { ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second) defer cancel() - streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request) + streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request) if tt.expectError { require.NotNil(t, err) @@ -837,7 +838,7 @@ func TestTranscriptionStreamNilInput(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - stream, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, tt.request) + stream, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, tt.request) require.NotNil(t, err) assert.Nil(t, stream) @@ -1250,7 +1251,7 @@ func TestTranscriptionStreamEdgeCases(t 
*testing.T) { ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second) defer cancel() - streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request) + streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request) if tt.expectError { tt.validateResult(t, nil, err) @@ -1319,7 +1320,7 @@ func TestTranscriptionStreamContextCancellation(t *testing.T) { ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request) + streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request) require.Nil(t, err) require.NotNil(t, streamChan) @@ -1532,8 +1533,8 @@ func TestMistralTranscriptionIntegration(t *testing.T) { assert.NotNil(t, resp) // TODO: Send a proper audio file with speech to validate resp.Text is non-empty // assert.NotEmpty(t, resp.Text) - assert.Equal(t, schemas.TranscriptionRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) + // Note: ExtraFields.Provider/RequestType are populated by bifrost.go's + // dispatcher, not by provider methods called in isolation. 
t.Logf(" Transcribed text: %s", resp.Text) } @@ -1576,7 +1577,7 @@ func TestMistralTranscriptionStreamIntegration(t *testing.T) { } t.Log("🎤 Testing Mistral streaming transcription with voxtral-mini-latest...") - streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request) + streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request) if err != nil { // Log the error but don't fail - the minimal audio may not be valid for Mistral @@ -1622,8 +1623,8 @@ func TestMistralTranscriptionStreamIntegration(t *testing.T) { t.Logf(" Total chunks received: %d", chunkCount) t.Logf(" Transcribed text: %s", allText) - if lastResponse != nil { - assert.Equal(t, schemas.TranscriptionStreamRequest, lastResponse.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, lastResponse.ExtraFields.Provider) - } + // Note: ExtraFields.Provider/RequestType on stream chunks are populated + // by bifrost.go's dispatcher, not by provider streaming methods called + // in isolation. + _ = lastResponse } diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index eac617df6e..13e2cb4e33 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -2,6 +2,7 @@ package nebius import ( + "context" "fmt" "net/http" "net/url" @@ -17,7 +18,8 @@ import ( // NebiusProvider implements the Provider interface for Nebius's API. 
type NebiusProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -44,6 +46,7 @@ func NewNebiusProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.tokenfactory.nebius.com" @@ -53,6 +56,7 @@ func NewNebiusProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* return &NebiusProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -102,7 +106,7 @@ func (provider *NebiusProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Nebius's API. // It formats the request, sends it to Nebius, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *NebiusProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -110,7 +114,7 @@ func (provider *NebiusProvider) TextCompletionStream(ctx *schemas.BifrostContext // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, authHeader, @@ -123,6 +127,7 @@ func (provider *NebiusProvider) TextCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -159,7 +164,7 @@ func (provider *NebiusProvider) ChatCompletion(ctx *schemas.BifrostContext, key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Nebius's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *NebiusProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -168,7 +173,7 @@ func (provider *NebiusProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, authHeader, @@ -183,6 +188,7 @@ func (provider *NebiusProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -198,11 +204,12 @@ func (provider *NebiusProvider) Responses(ctx *schemas.BifrostContext, key schem } // ResponsesStream performs a streaming responses request to the Nebius API. 
-func (provider *NebiusProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -242,7 +249,7 @@ func (provider *NebiusProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key } // SpeechStream is not supported by the Nebius provider. -func (provider *NebiusProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -252,7 +259,7 @@ func (provider *NebiusProvider) Transcription(ctx *schemas.BifrostContext, key s } // TranscriptionStream is not supported by the Nebius provider. 
-func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -364,7 +371,7 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key } // ImageGenerationStream is not supported by Nebius provider. -func (provider *NebiusProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -374,7 +381,7 @@ func (provider *NebiusProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } // ImageEditStream is not supported by the Nebius provider. 
-func (provider *NebiusProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -523,6 +530,6 @@ func (provider *NebiusProvider) Passthrough(_ *schemas.BifrostContext, _ schemas return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *NebiusProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *NebiusProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index b84d3f7c9c..68ae599a7b 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -3,6 +3,7 @@ package ollama import ( + "context" "strings" "time" @@ -15,7 +16,8 @@ import ( // OllamaProvider implements the Provider interface for Ollama's API. 
type OllamaProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -47,12 +49,14 @@ func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") // BaseURL is optional when keys have ollama_key_config with per-key URLs return &OllamaProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -112,10 +116,10 @@ func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Ollama's API. // It formats the request, sends it to Ollama, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, @@ -128,6 +132,7 @@ func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -153,11 +158,11 @@ func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Ollama's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, @@ -172,6 +177,7 @@ func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -188,11 +194,12 @@ func (provider *OllamaProvider) Responses(ctx *schemas.BifrostContext, key schem } // ResponsesStream performs a streaming responses request to the Ollama API. -func (provider *OllamaProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -231,7 +238,7 @@ func (provider *OllamaProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key } // SpeechStream is not supported by the Ollama provider. 
-func (provider *OllamaProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -241,7 +248,7 @@ func (provider *OllamaProvider) Transcription(ctx *schemas.BifrostContext, key s } // TranscriptionStream is not supported by the Ollama provider. -func (provider *OllamaProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -251,7 +258,7 @@ func (provider *OllamaProvider) ImageGeneration(ctx *schemas.BifrostContext, key } // ImageGenerationStream is not supported by the Ollama provider. 
-func (provider *OllamaProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -261,7 +268,7 @@ func (provider *OllamaProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } // ImageEditStream is not supported by the Ollama provider. -func (provider *OllamaProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -409,6 +416,6 @@ func (provider *OllamaProvider) Passthrough(_ *schemas.BifrostContext, _ schemas return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *OllamaProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OllamaProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ 
func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index a4e06dac47..8833252738 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -25,7 +25,8 @@ import ( // OpenAIProvider implements the Provider interface for OpenAI's GPT API. type OpenAIProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -59,6 +60,7 @@ func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *O client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.openai.com" @@ -68,6 +70,7 @@ func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *O return &OpenAIProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -383,7 +386,7 
@@ func HandleOpenAITextCompletionRequest( // TextCompletionStream performs a streaming text completion request to OpenAI's API. // It formats the request, sends it to OpenAI, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil { return nil, err } @@ -393,7 +396,7 @@ func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext } return HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/completions", schemas.TextCompletionStreamRequest), request, authHeader, @@ -406,6 +409,7 @@ func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -426,6 +430,7 @@ func HandleOpenAITextCompletionStreaming( customResponseHandler responseHandler[schemas.BifrostTextCompletionResponse], postResponseConverter func(*schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { headers := map[string]string{ "Content-Type": "application/json", @@ -524,12 +529,12 @@ func HandleOpenAITextCompletionStreaming( // Start streaming in a goroutine go 
func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -551,7 +556,7 @@ func HandleOpenAITextCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -578,7 +583,7 @@ func HandleOpenAITextCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) return } break @@ -596,7 +601,7 @@ func HandleOpenAITextCompletionStreaming( handlerErr.ExtraFields.RawResponse = rawResponse } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, handlerErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + 
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, handlerErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } else { @@ -608,7 +613,7 @@ func HandleOpenAITextCompletionStreaming( if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } @@ -691,7 +696,7 @@ func HandleOpenAITextCompletionStreaming( response.ExtraFields.RawResponse = jsonData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(&response, nil, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(&response, nil, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } // For providers that don't send [DONE] marker break on finish_reason @@ -714,7 +719,7 @@ func HandleOpenAITextCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) }() return responseChan, nil @@ -888,7 
+893,7 @@ func HandleOpenAIChatCompletionRequest( // ChatCompletionStream handles streaming for OpenAI chat completions. // It formats messages, prepares request body, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. -func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err @@ -907,7 +912,7 @@ func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Use shared streaming logic return HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/chat/completions", schemas.ChatCompletionStreamRequest), request, authHeader, @@ -922,6 +927,7 @@ func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -944,6 +950,7 @@ func HandleOpenAIChatCompletionStreaming( postRequestConverter func(*OpenAIChatRequest) *OpenAIChatRequest, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if the request is a redirect from ResponsesStream to ChatCompletionStream isResponsesToChatCompletionsFallback := false @@ 
-1060,12 +1067,12 @@ func HandleOpenAIChatCompletionStreaming( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } // Release the responses stream state if it was acquired (for ResponsesToChatCompletions fallback) schemas.ReleaseChatToResponsesStreamState(responsesStreamState) @@ -1089,7 +1096,7 @@ func HandleOpenAIChatCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -1120,7 +1127,7 @@ func HandleOpenAIChatCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) return } break @@ -1134,7 +1141,7 @@ func HandleOpenAIChatCompletionStreaming( if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } @@ -1153,7 +1160,7 @@ func HandleOpenAIChatCompletionStreaming( handlerErr.ExtraFields.RawResponse = rawResponse } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, handlerErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, handlerErr, jsonBody, nil, 
sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } else { @@ -1189,7 +1196,7 @@ func HandleOpenAIChatCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -1206,14 +1213,14 @@ func HandleOpenAIChatCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) return } response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer) } } else { if postResponseConverter != nil { @@ -1299,7 +1306,7 @@ func HandleOpenAIChatCompletionStreaming( response.ExtraFields.RawResponse = jsonData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, &response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, 
postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, &response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } // For providers that don't send [DONE] marker break on finish_reason @@ -1324,7 +1331,7 @@ func HandleOpenAIChatCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer) } }() @@ -1495,7 +1502,7 @@ func HandleOpenAIResponsesRequest( } // ResponsesStream performs a streaming responses request to the OpenAI API. -func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err @@ -1514,7 +1521,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Use shared streaming logic return HandleOpenAIResponsesStreaming( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesStreamRequest), request, authHeader, @@ -1528,6 +1535,7 @@ func (provider 
*OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -1549,6 +1557,7 @@ func HandleOpenAIResponsesStreaming( postRequestConverter func(*OpenAIResponsesRequest) *OpenAIResponsesRequest, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Prepare SGL headers (SGL typically doesn't require authorization, but we include it if provided) headers := map[string]string{ @@ -1648,12 +1657,12 @@ func HandleOpenAIResponsesStreaming( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1675,7 +1684,7 @@ func HandleOpenAIResponsesStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -1697,7 +1706,7 @@ func HandleOpenAIResponsesStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -1716,7 +1725,7 @@ func HandleOpenAIResponsesStreaming( bifrostErr.ExtraFields.RawResponse = rawResponse } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } else { @@ -1763,7 +1772,7 @@ func HandleOpenAIResponsesStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, []byte(jsonData), sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, []byte(jsonData), sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -1780,7 +1789,7 @@ 
func HandleOpenAIResponsesStreaming( bifrostErr.Error.Code = &response.Response.Error.Code } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, []byte(jsonData), sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, []byte(jsonData), sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } @@ -1792,14 +1801,14 @@ func HandleOpenAIResponsesStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan, postHookSpanFinalizer) return } response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan, postHookSpanFinalizer) } } }() @@ -2097,7 +2106,7 @@ func HandleOpenAISpeechRequest( // SpeechStream handles streaming for speech synthesis. // It formats the request body, creates HTTP request, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. 
-func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err } @@ -2115,7 +2124,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo return HandleOpenAISpeechStreamRequest( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/audio/speech", schemas.SpeechStreamRequest), request, authHeader, @@ -2127,6 +2136,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -2146,6 +2156,7 @@ func HandleOpenAISpeechStreamRequest( postRequestConverter func(*OpenAISpeechRequest) *OpenAISpeechRequest, postResponseConverter func(*schemas.BifrostSpeechStreamResponse) *schemas.BifrostSpeechStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -2242,12 +2253,12 @@ func HandleOpenAISpeechStreamRequest( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, 
logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -2269,7 +2280,7 @@ func HandleOpenAISpeechStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -2293,7 +2304,7 @@ func HandleOpenAISpeechStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -2306,7 +2317,7 @@ func HandleOpenAISpeechStreamRequest( if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } @@ -2346,11 +2357,11 @@ func HandleOpenAISpeechStreamRequest( } 
response.BackfillParams(request) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan, postHookSpanFinalizer) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan, postHookSpanFinalizer) } }() @@ -2541,7 +2552,7 @@ func HandleOpenAITranscriptionRequest( } // TranscriptionStream performs a streaming transcription request to the OpenAI API. -func (provider *OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { return nil, err } @@ -2553,7 +2564,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, return HandleOpenAITranscriptionStreamRequest( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/audio/transcriptions", schemas.TranscriptionStreamRequest), request, authHeader, @@ -2566,6 +2577,7 @@ func (provider 
*OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -2586,6 +2598,7 @@ func HandleOpenAITranscriptionStreamRequest( postRequestConverter func(*OpenAITranscriptionRequest) *OpenAITranscriptionRequest, postResponseConverter func(*schemas.BifrostTranscriptionStreamResponse) *schemas.BifrostTranscriptionStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Use centralized converter reqBody := ToOpenAITranscriptionRequest(request) @@ -2678,12 +2691,12 @@ func HandleOpenAITranscriptionStreamRequest( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -2705,7 +2718,7 @@ func HandleOpenAITranscriptionStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -2730,7 +2743,7 @@ func HandleOpenAITranscriptionStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -2745,7 +2758,7 @@ func HandleOpenAITranscriptionStreamRequest( bifrostErr.ExtraFields.RawResponse = jsonData } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), []byte(jsonData), false, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), []byte(jsonData), false, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } else { @@ -2757,7 +2770,7 @@ func HandleOpenAITranscriptionStreamRequest( if bifrostErrVal.Error != nil && bifrostErrVal.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) respBody := append([]byte(nil), resp.Body()...) 
- providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } @@ -2798,11 +2811,11 @@ func HandleOpenAITranscriptionStreamRequest( response.Text = fullTranscriptionText } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer) } }() @@ -2960,6 +2973,7 @@ func HandleOpenAIImageGenerationRequest( func (provider *OpenAIProvider) ImageGenerationStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -2979,7 +2993,7 @@ func (provider *OpenAIProvider) ImageGenerationStream( // Use shared streaming logic return HandleOpenAIImageGenerationStreaming( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationStreamRequest), request, authHeader, @@ -2992,6 +3006,7 @@ func (provider *OpenAIProvider) ImageGenerationStream( nil, nil, provider.logger, 
+ postHookSpanFinalizer, ) } @@ -3010,6 +3025,7 @@ func HandleOpenAIImageGenerationStreaming( postRequestConverter func(*OpenAIImageGenerationRequest) *OpenAIImageGenerationRequest, postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Set headers headers := map[string]string{ @@ -3110,12 +3126,12 @@ func HandleOpenAIImageGenerationStreaming( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -3137,7 +3153,7 @@ func HandleOpenAIImageGenerationStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -3164,7 +3180,7 @@ func HandleOpenAIImageGenerationStreaming( if readErr != nil { if readErr != io.EOF { logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -3177,7 +3193,7 @@ func HandleOpenAIImageGenerationStreaming( if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } @@ -3210,7 +3226,7 @@ func HandleOpenAIImageGenerationStreaming( } } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger, postHookSpanFinalizer) return } @@ -3353,7 +3369,7 @@ func HandleOpenAIImageGenerationStreaming( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, 
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk), - responseChan) + responseChan, postHookSpanFinalizer) if isCompleted { return @@ -4208,7 +4224,7 @@ func HandleOpenAIImageEditRequest( } // ImageEditStream streams image edits via the OpenAI Images API. -func (provider *OpenAIProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenAIProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if image generation stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageEditStreamRequest); err != nil { return nil, err @@ -4221,7 +4237,7 @@ func (provider *OpenAIProvider) ImageEditStream(ctx *schemas.BifrostContext, pos return HandleOpenAIImageEditStreamRequest( ctx, - provider.client, + provider.streamingClient, provider.buildRequestURL(ctx, "/v1/images/edits", schemas.ImageEditStreamRequest), request, authHeader, @@ -4234,6 +4250,7 @@ func (provider *OpenAIProvider) ImageEditStream(ctx *schemas.BifrostContext, pos nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -4252,6 +4269,7 @@ func HandleOpenAIImageEditStreamRequest( postRequestConverter func(*OpenAIImageEditRequest) *OpenAIImageEditRequest, postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { reqBody := ToOpenAIImageEditRequest(request) if reqBody == nil { @@ -4341,12 +4359,12 @@ func 
HandleOpenAIImageEditStreamRequest( // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -4368,7 +4386,7 @@ func HandleOpenAIImageEditStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer) return } @@ -4395,7 +4413,7 @@ func HandleOpenAIImageEditStreamRequest( if readErr != nil { if readErr != io.EOF { logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -4408,7 +4426,7 @@ func HandleOpenAIImageEditStreamRequest( if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, 
postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer) return } } @@ -4441,7 +4459,7 @@ func HandleOpenAIImageEditStreamRequest( } } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger, postHookSpanFinalizer) return } @@ -4580,7 +4598,7 @@ func HandleOpenAIImageEditStreamRequest( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk), - responseChan) + responseChan, postHookSpanFinalizer) if isCompleted { return @@ -6895,6 +6913,7 @@ func (provider *OpenAIProvider) Passthrough( func (provider *OpenAIProvider) PassthroughStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -6933,7 +6952,7 @@ func (provider *OpenAIProvider) PassthroughStream( fasthttpReq.SetBody(req.Body) - activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp) + activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp) startTime := time.Now() @@ -6983,12 +7002,12 @@ func (provider *OpenAIProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == 
context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } close(ch) }() @@ -7026,8 +7045,8 @@ func (provider *OpenAIProvider) PassthroughStream( }, } postHookRunner(ctx, finalResp, nil) - if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { - finalizer(ctx) + if postHookSpanFinalizer != nil { + postHookSpanFinalizer(ctx) } return } @@ -7037,7 +7056,7 @@ func (provider *OpenAIProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer) return } } diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 63ae8f48e4..36e4ff0566 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -2,6 +2,7 @@ package openrouter import ( + "context" "fmt" "net/http" "strings" @@ -16,7 +17,8 @@ import ( // OpenRouterProvider implements the Provider interface for OpenRouter's API. 
type OpenRouterProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -43,6 +45,7 @@ func NewOpenRouterProvider(config *schemas.ProviderConfig, logger schemas.Logger client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://openrouter.ai/api" @@ -52,6 +55,7 @@ func NewOpenRouterProvider(config *schemas.ProviderConfig, logger schemas.Logger return &OpenRouterProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -277,7 +281,7 @@ func (provider *OpenRouterProvider) TextCompletion(ctx *schemas.BifrostContext, // TextCompletionStream performs a streaming text completion request to OpenRouter's API. // It formats the request, sends it to OpenRouter, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *OpenRouterProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string keyValue := key.Value.GetValue() if keyValue != "" { @@ -285,7 +289,7 @@ func (provider *OpenRouterProvider) TextCompletionStream(ctx *schemas.BifrostCon } return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/completions", request, authHeader, @@ -298,6 +302,7 @@ func (provider *OpenRouterProvider) TextCompletionStream(ctx *schemas.BifrostCon nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -323,7 +328,7 @@ func (provider *OpenRouterProvider) ChatCompletion(ctx *schemas.BifrostContext, // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses OpenRouter's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *OpenRouterProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string keyValue := key.Value.GetValue() if keyValue != "" { @@ -332,7 +337,7 @@ func (provider *OpenRouterProvider) ChatCompletionStream(ctx *schemas.BifrostCon // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, authHeader, @@ -347,6 +352,7 @@ func (provider *OpenRouterProvider) ChatCompletionStream(ctx *schemas.BifrostCon nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -369,7 +375,7 @@ func (provider *OpenRouterProvider) Responses(ctx *schemas.BifrostContext, key s } // ResponsesStream performs a streaming responses request to the OpenRouter API. 
-func (provider *OpenRouterProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string keyValue := key.Value.GetValue() if keyValue != "" { @@ -377,7 +383,7 @@ func (provider *OpenRouterProvider) ResponsesStream(ctx *schemas.BifrostContext, } return openai.HandleOpenAIResponsesStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"), request, authHeader, @@ -391,6 +397,7 @@ func (provider *OpenRouterProvider) ResponsesStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -427,7 +434,7 @@ func (provider *OpenRouterProvider) OCR(ctx *schemas.BifrostContext, key schemas } // SpeechStream is not supported by the OpenRouter provider. 
-func (provider *OpenRouterProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -437,7 +444,7 @@ func (provider *OpenRouterProvider) Transcription(ctx *schemas.BifrostContext, k } // TranscriptionStream is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -447,7 +454,7 @@ func (provider *OpenRouterProvider) ImageGeneration(ctx *schemas.BifrostContext, } // ImageGenerationStream is not supported by the OpenRouter provider. 
-func (provider *OpenRouterProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -457,7 +464,7 @@ func (provider *OpenRouterProvider) ImageEdit(ctx *schemas.BifrostContext, key s } // ImageEditStream is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -606,6 +613,6 @@ func (provider *OpenRouterProvider) Passthrough(_ *schemas.BifrostContext, _ sch return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *OpenRouterProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *OpenRouterProvider) PassthroughStream(_ *schemas.BifrostContext, _ 
schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index e03d891d38..ae4cb22ab7 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -3,6 +3,7 @@ package parasail import ( + "context" "strings" "time" @@ -15,7 +16,8 @@ import ( // ParasailProvider implements the Provider interface for Parasail's API. type ParasailProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -42,6 +44,7 @@ func NewParasailProvider(config *schemas.ProviderConfig, logger schemas.Logger) client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.parasail.io" @@ -51,6 +54,7 @@ func NewParasailProvider(config *schemas.ProviderConfig, logger schemas.Logger) return &ParasailProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: 
config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -85,7 +89,7 @@ func (provider *ParasailProvider) TextCompletion(ctx *schemas.BifrostContext, ke // TextCompletionStream performs a streaming text completion request to Parasail's API. // It formats the request, sends it to Parasail, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *ParasailProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -111,7 +115,7 @@ func (provider *ParasailProvider) ChatCompletion(ctx *schemas.BifrostContext, ke // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Parasail's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *ParasailProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -119,7 +123,7 @@ func (provider *ParasailProvider) ChatCompletionStream(ctx *schemas.BifrostConte // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/chat/completions", request, authHeader, @@ -134,6 +138,7 @@ func (provider *ParasailProvider) ChatCompletionStream(ctx *schemas.BifrostConte nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -150,11 +155,12 @@ func (provider *ParasailProvider) Responses(ctx *schemas.BifrostContext, key sch } // ResponsesStream performs a streaming responses request to the Parasail API. 
-func (provider *ParasailProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -171,7 +177,7 @@ func (provider *ParasailProvider) Speech(ctx *schemas.BifrostContext, key schema } // SpeechStream is not supported by the Parasail provider. -func (provider *ParasailProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -181,7 +187,7 @@ func (provider *ParasailProvider) Transcription(ctx *schemas.BifrostContext, key } // TranscriptionStream is not supported by the Parasail provider. 
-func (provider *ParasailProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -201,7 +207,7 @@ func (provider *ParasailProvider) ImageGeneration(ctx *schemas.BifrostContext, k } // ImageGenerationStream is not supported by the Parasail provider. -func (provider *ParasailProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -211,7 +217,7 @@ func (provider *ParasailProvider) ImageEdit(ctx *schemas.BifrostContext, key sch } // ImageEditStream is not supported by the Parasail provider. 
-func (provider *ParasailProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -361,6 +367,6 @@ func (provider *ParasailProvider) Passthrough(_ *schemas.BifrostContext, _ schem } // PassthroughStream is not supported by the Parasail provider. -func (provider *ParasailProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ParasailProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index f0b21ec21d..addb6a5fb8 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -3,6 +3,7 @@ package perplexity import ( + "context" "fmt" "net/http" "strings" @@ -17,7 +18,8 @@ import ( // PerplexityProvider implements the Provider interface for Perplexity's API. 
type PerplexityProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -44,6 +46,7 @@ func NewPerplexityProvider(config *schemas.ProviderConfig, logger schemas.Logger client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided if config.NetworkConfig.BaseURL == "" { config.NetworkConfig.BaseURL = "https://api.perplexity.ai" @@ -53,6 +56,7 @@ func NewPerplexityProvider(config *schemas.ProviderConfig, logger schemas.Logger return &PerplexityProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -129,7 +133,7 @@ func (provider *PerplexityProvider) TextCompletion(ctx *schemas.BifrostContext, // TextCompletionStream performs a streaming text completion request to Perplexity's API. // It formats the request, sends it to Perplexity, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *PerplexityProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -180,7 +184,7 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Perplexity's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *PerplexityProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -193,7 +197,7 @@ func (provider *PerplexityProvider) ChatCompletionStream(ctx *schemas.BifrostCon // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/chat/completions", request, authHeader, @@ -208,6 +212,7 @@ func (provider *PerplexityProvider) ChatCompletionStream(ctx *schemas.BifrostCon nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -224,11 +229,12 @@ func (provider *PerplexityProvider) Responses(ctx *schemas.BifrostContext, key s } // ResponsesStream performs a streaming responses request to the Perplexity API. 
-func (provider *PerplexityProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -255,7 +261,7 @@ func (provider *PerplexityProvider) OCR(ctx *schemas.BifrostContext, key schemas } // SpeechStream is not supported by the Perplexity provider. -func (provider *PerplexityProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -265,7 +271,7 @@ func (provider *PerplexityProvider) Transcription(ctx *schemas.BifrostContext, k } // TranscriptionStream is not supported by the Perplexity provider. 
-func (provider *PerplexityProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -275,7 +281,7 @@ func (provider *PerplexityProvider) ImageGeneration(ctx *schemas.BifrostContext, } // ImageGenerationStream is not supported by the Perplexity provider. -func (provider *PerplexityProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -285,7 +291,7 @@ func (provider *PerplexityProvider) ImageEdit(ctx *schemas.BifrostContext, key s } // ImageEditStream is not supported by the Perplexity provider. 
-func (provider *PerplexityProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -434,6 +440,6 @@ func (provider *PerplexityProvider) Passthrough(_ *schemas.BifrostContext, _ sch return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *PerplexityProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *PerplexityProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index aedf38471c..2a987ccec2 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -24,7 +24,8 @@ import ( // ReplicateProvider implements the Provider interface for Replicate's API. 
type ReplicateProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -52,6 +53,7 @@ func NewReplicateProvider(config *schemas.ProviderConfig, logger schemas.Logger) client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") if config.NetworkConfig.BaseURL == "" { @@ -61,6 +63,7 @@ func NewReplicateProvider(config *schemas.ProviderConfig, logger schemas.Logger) return &ReplicateProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -500,7 +503,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k // TextCompletionStream performs a streaming text completion request to replicate's API. // It formats the request, sends it to replicate, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. 
-func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil { return nil, err } @@ -559,7 +562,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont streamURL := *prediction.URLs.Stream // Connect to stream URL - _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.client, streamURL, key) + _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key) if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -581,12 +584,12 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, 
postHookSpanFinalizer) } close(responseChan) }() @@ -632,7 +635,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) } break } @@ -680,7 +683,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) } case "done": @@ -700,7 +703,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) // Explicitly close the body stream to terminate connection to Replicate resp.CloseBodyStream() return @@ -715,7 +718,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, 
provider.sendBackRawRequest, provider.sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) // Explicitly close the body stream to terminate connection to Replicate resp.CloseBodyStream() return @@ -740,7 +743,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(finalResponse, nil, nil, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) resp.CloseBodyStream() return } @@ -839,7 +842,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k // ChatCompletionStream performs a streaming chat completion request to the replicate API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -898,7 +901,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont streamURL := *prediction.URLs.Stream // Connect to stream URL - _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.client, streamURL, key) + _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key) if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -920,12 +923,12 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } 
close(responseChan) }() @@ -971,7 +974,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) } break } @@ -1026,7 +1029,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) } case "done": @@ -1046,7 +1049,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) // Explicitly close the body stream to terminate connection to Replicate resp.CloseBodyStream() return @@ -1061,7 +1064,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, 
provider.sendBackRawRequest, provider.sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) // Explicitly close the body stream to terminate connection to Replicate resp.CloseBodyStream() return @@ -1097,7 +1100,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, finalResponse, nil, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) resp.CloseBodyStream() return } @@ -1191,7 +1194,7 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc } // ResponsesStream performs a streaming responses request to the replicate API. -func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } @@ -1268,7 +1271,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) // Make the streaming request - streamErr := provider.client.Do(req, resp) + streamErr := provider.streamingClient.Do(req, resp) if streamErr 
!= nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(streamErr, context.Canceled) { @@ -1314,12 +1317,12 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, // Registered first so the post-hook span finalizer runs on every exit // path — including the empty-reader early return below, which would // otherwise skip any finalizer declared later in this goroutine. - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1334,7 +1337,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, "provider returned an empty response", fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse), responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse), responseChan, provider.logger, postHookSpanFinalizer) return } @@ -1387,7 +1390,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, 
sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) return } break @@ -1430,7 +1433,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, createdResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ hasEmittedCreated = true } @@ -1450,7 +1453,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, inProgressResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ hasEmittedInProgress = true } @@ -1479,7 +1482,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemAddedResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ hasEmittedOutputItemAdded = true } @@ -1507,7 +1510,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partAddedResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ hasEmittedContentPartAdded = true } @@ -1527,7 +1530,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, deltaResp, nil, nil, nil), - responseChan) + responseChan, 
postHookSpanFinalizer) sequenceNumber++ hasReceivedContent = true } @@ -1553,7 +1556,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, textDoneResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ // response.content_part.done @@ -1576,7 +1579,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partDoneResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ // response.output_item.done @@ -1610,7 +1613,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemDoneResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) sequenceNumber++ } @@ -1643,7 +1646,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, completedResp, nil, nil, nil), - responseChan) + responseChan, postHookSpanFinalizer) resp.CloseBodyStream() return case "error": @@ -1668,7 +1671,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, 
provider.logger, postHookSpanFinalizer) resp.CloseBodyStream() return } @@ -1700,7 +1703,7 @@ func (provider *ReplicateProvider) OCR(ctx *schemas.BifrostContext, key schemas. } // SpeechStream is not supported by the replicate provider. -func (provider *ReplicateProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -1710,7 +1713,7 @@ func (provider *ReplicateProvider) Transcription(ctx *schemas.BifrostContext, ke } // TranscriptionStream is not supported by the replicate provider. -func (provider *ReplicateProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1807,7 +1810,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, // ImageGenerationStream performs a streaming image generation request to the replicate API. // It creates a prediction with streaming enabled and listens to the stream URL for progressive updates. 
-func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ImageGenerationStreamRequest); err != nil { return nil, err } @@ -1872,7 +1875,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon streamURL := *prediction.URLs.Stream // Connect to stream URL - _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.client, streamURL, key) + _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key) if bifrostErr != nil { return nil, bifrostErr } @@ -1894,12 +1897,12 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -1947,7 +1950,7 @@ func (provider *ReplicateProvider) 
ImageGenerationStream(ctx *schemas.BifrostCon ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading SSE stream: %v", readErr)) enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) } break } @@ -2011,7 +2014,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk), - responseChan) + responseChan, postHookSpanFinalizer) case "done": // Parse done event data @@ -2034,7 +2037,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon } bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( @@ -2046,7 +2049,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon } bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, 
bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -2079,7 +2082,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, finalChunk), - responseChan) + responseChan, postHookSpanFinalizer) return case "error": @@ -2110,7 +2113,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon } bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } } @@ -2212,7 +2215,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc // ImageEditStream performs a streaming image edit request to the replicate API. // It creates a prediction with streaming enabled and listens to the stream URL for progressive updates. 
-func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.ImageEditStreamRequest); err != nil { return nil, err } @@ -2278,7 +2281,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, streamURL := *prediction.URLs.Stream // Connect to stream URL - _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.client, streamURL, key) + _, resp, bifrostErr := listenToReplicateStreamURL(ctx, provider.streamingClient, streamURL, key) if bifrostErr != nil { return nil, bifrostErr } @@ -2300,12 +2303,12 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer) } close(responseChan) }() @@ -2352,7 +2355,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx 
*schemas.BifrostContext, } enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("stream read error", readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger, postHookSpanFinalizer) } break } @@ -2414,7 +2417,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk), - responseChan) + responseChan, postHookSpanFinalizer) case "done": // Parse done event data @@ -2436,7 +2439,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, } bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( @@ -2447,7 +2450,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, } bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } @@ -2476,7 +2479,7 @@ func 
(provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, finalChunk), - responseChan) + responseChan, postHookSpanFinalizer) return case "error": @@ -2496,7 +2499,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, } bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger, postHookSpanFinalizer) return } } @@ -3284,6 +3287,6 @@ func (provider *ReplicateProvider) Passthrough(_ *schemas.BifrostContext, _ sche return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *ReplicateProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *ReplicateProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/runway/runway.go b/core/providers/runway/runway.go index 6bb5e31e32..e51e5e6355 100644 --- a/core/providers/runway/runway.go +++ b/core/providers/runway/runway.go @@ -3,6 +3,7 @@ package runway import ( + "context" "fmt" "net/http" "strings" @@ -75,7 +76,7 @@ func (provider 
*RunwayProvider) TextCompletion(ctx *schemas.BifrostContext, key } // TextCompletionStream is not supported by the Runway provider. -func (provider *RunwayProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -85,7 +86,7 @@ func (provider *RunwayProvider) ChatCompletion(ctx *schemas.BifrostContext, key } // ChatCompletionStream is not supported by the Runway provider. -func (provider *RunwayProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionStreamRequest, provider.GetProviderKey()) } @@ -95,7 +96,7 @@ func (provider *RunwayProvider) Responses(ctx *schemas.BifrostContext, key schem } // ResponsesStream is not supported by the Runway provider. 
-func (provider *RunwayProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesStreamRequest, provider.GetProviderKey()) } @@ -110,7 +111,7 @@ func (provider *RunwayProvider) Speech(ctx *schemas.BifrostContext, key schemas. } // SpeechStream is not supported by the Runway provider. -func (provider *RunwayProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -120,7 +121,7 @@ func (provider *RunwayProvider) Transcription(ctx *schemas.BifrostContext, key s } // TranscriptionStream is not supported by the Runway provider. 
-func (provider *RunwayProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -140,7 +141,7 @@ func (provider *RunwayProvider) ImageGeneration(ctx *schemas.BifrostContext, key } // ImageGenerationStream is not supported by the Runway provider. -func (provider *RunwayProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -150,7 +151,7 @@ func (provider *RunwayProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } // ImageEditStream is not supported by the Runway provider. 
-func (provider *RunwayProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -560,6 +561,6 @@ func (provider *RunwayProvider) Passthrough(_ *schemas.BifrostContext, _ schemas return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *RunwayProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *RunwayProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index 5b07356851..cd6b61809c 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -3,6 +3,7 @@ package sgl import ( + "context" "strings" "time" @@ -15,7 +16,8 @@ import ( // SGLProvider implements the Provider interface for SGL's API. 
type SGLProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -47,12 +49,14 @@ func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGL client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") // BaseURL is optional when keys have sgl_key_config with per-key URLs return &SGLProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -113,10 +117,10 @@ func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key sch // TextCompletionStream performs a streaming text completion request to SGL's API. // It formats the request, sends it to SGL, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
-func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, @@ -129,6 +133,7 @@ func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, p nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -154,11 +159,11 @@ func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses SGL's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, @@ -173,6 +178,7 @@ func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, p nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -189,11 +195,12 @@ func (provider *SGLProvider) Responses(ctx *schemas.BifrostContext, key schemas. } // ResponsesStream performs a streaming responses request to the SGL API. -func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -232,7 +239,7 @@ func (provider *SGLProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, r } // SpeechStream is not supported by the SGL provider. 
-func (provider *SGLProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -242,7 +249,7 @@ func (provider *SGLProvider) Transcription(ctx *schemas.BifrostContext, key sche } // TranscriptionStream is not supported by the SGL provider. -func (provider *SGLProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -252,7 +259,7 @@ func (provider *SGLProvider) ImageGeneration(ctx *schemas.BifrostContext, key sc } // ImageGenerationStream is not supported by the SGL provider. 
-func (provider *SGLProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -262,7 +269,7 @@ func (provider *SGLProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas. } // ImageEditStream is not supported by the SGL provider. -func (provider *SGLProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -411,6 +418,6 @@ func (provider *SGLProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Ke return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *SGLProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *SGLProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ 
schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/utils/decompression_test.go b/core/providers/utils/decompression_test.go index 16ed30d608..1307340b06 100644 --- a/core/providers/utils/decompression_test.go +++ b/core/providers/utils/decompression_test.go @@ -496,13 +496,12 @@ func TestSafeReset(t *testing.T) { if ok { t.Fatal("expected false for panicking reset") } - t.Run("panic_nonnnil", func(t *testing.T) { + }) + + t.Run("panic_nonnil", func(t *testing.T) { ok := safeReset(func() error { panic("") }) if ok { - t.Fatal("expected false for nil panic") - } - if ok { - t.Fatal("expected false for nil panic") + t.Fatal("expected false for empty-string panic") } }) diff --git a/core/providers/utils/large_response.go b/core/providers/utils/large_response.go index e62d375c9a..c1c5da8a15 100644 --- a/core/providers/utils/large_response.go +++ b/core/providers/utils/large_response.go @@ -61,12 +61,19 @@ func (r *LargeResponseReader) Close() error { // BuildLargeResponseClient creates a streaming-enabled fasthttp client for large response detection. // The client caps buffering at the threshold and enables response body streaming. +// +// ReadTimeout/WriteTimeout/MaxConnDuration are zeroed: large-response bodies may take arbitrarily +// long to download, and fasthttp's ReadTimeout bounds *full* body read — not idle. Idle detection +// on stalled streams is handled separately (see NewIdleTimeoutReader / SetupStreamingPassthrough). 
func BuildLargeResponseClient(base *fasthttp.Client, responseThreshold int64) *fasthttp.Client { client := CloneFastHTTPClientConfig(base) if responseThreshold > 0 && responseThreshold <= int64(math.MaxInt) { client.MaxResponseBodySize = int(responseThreshold) } client.StreamResponseBody = true + client.ReadTimeout = 0 + client.WriteTimeout = 0 + client.MaxConnDuration = 0 return client } diff --git a/core/providers/utils/make_request_test.go b/core/providers/utils/make_request_test.go index ec2bf771bc..3a66ff986c 100644 --- a/core/providers/utils/make_request_test.go +++ b/core/providers/utils/make_request_test.go @@ -309,9 +309,9 @@ func TestNewBifrostTimeoutError(t *testing.T) { if err.Error.Message != "test timeout" { t.Fatalf("expected 'test timeout', got %s", err.Error.Message) } - if err.ExtraFields.Provider != "openai" { - t.Fatalf("expected provider openai, got %s", err.ExtraFields.Provider) - } + // Note: ExtraFields.Provider is populated by bifrost.go's dispatcher via + // PopulateExtraFields, not by NewBifrostTimeoutError — the constructor has + // no provider context. } func TestMakeRequestWithContext_ClientError(t *testing.T) { diff --git a/core/providers/utils/streaming_client_test.go b/core/providers/utils/streaming_client_test.go new file mode 100644 index 0000000000..0ed7878675 --- /dev/null +++ b/core/providers/utils/streaming_client_test.go @@ -0,0 +1,218 @@ +package utils + +import ( + "bufio" + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/valyala/fasthttp" +) + +// TestBuildStreamingClient_ZerosReadWriteTimeout verifies the streaming client +// has ReadTimeout=0 / WriteTimeout=0 / MaxConnDuration=0 while preserving other +// config from the base. 
+func TestBuildStreamingClient_ZerosReadWriteTimeout(t *testing.T) { + base := &fasthttp.Client{ + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + MaxConnDuration: 5 * time.Minute, + MaxConnWaitTimeout: 15 * time.Second, + MaxConnsPerHost: 123, + } + ConfigureDialer(base) + + stream := BuildStreamingClient(base) + + if stream.ReadTimeout != 0 { + t.Errorf("ReadTimeout: got %v, want 0", stream.ReadTimeout) + } + if stream.WriteTimeout != 0 { + t.Errorf("WriteTimeout: got %v, want 0", stream.WriteTimeout) + } + if stream.MaxConnDuration != 0 { + t.Errorf("MaxConnDuration: got %v, want 0", stream.MaxConnDuration) + } + if !stream.StreamResponseBody { + t.Error("StreamResponseBody: got false, want true") + } + if stream.MaxConnWaitTimeout != base.MaxConnWaitTimeout { + t.Errorf("MaxConnWaitTimeout should be preserved: got %v, want %v", + stream.MaxConnWaitTimeout, base.MaxConnWaitTimeout) + } + if stream.MaxConnsPerHost != base.MaxConnsPerHost { + t.Errorf("MaxConnsPerHost should be preserved: got %v, want %v", + stream.MaxConnsPerHost, base.MaxConnsPerHost) + } +} + +// TestBuildStreamingClient_BaseUnchanged verifies BuildStreamingClient does not +// mutate the base client (since unary callers still need the 30s timeout). +func TestBuildStreamingClient_BaseUnchanged(t *testing.T) { + base := &fasthttp.Client{ + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + MaxConnDuration: 5 * time.Minute, + } + _ = BuildStreamingClient(base) + + if base.ReadTimeout != 30*time.Second { + t.Errorf("base ReadTimeout mutated: got %v, want 30s", base.ReadTimeout) + } + if base.MaxConnDuration != 5*time.Minute { + t.Errorf("base MaxConnDuration mutated: got %v, want 5m", base.MaxConnDuration) + } +} + +// TestBuildStreamingClient_LongStreamSurvives verifies that a stream sending +// chunks every 500ms for 2.5s (total) is not killed by the base client's 1s +// ReadTimeout. Before the fix, fasthttp would abort at ~1s. 
+func TestBuildStreamingClient_LongStreamSurvives(t *testing.T) { + const chunkInterval = 500 * time.Millisecond + const totalChunks = 5 // 2.5s total, well past base ReadTimeout=1s + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + flusher, _ := w.(http.Flusher) + for i := 0; i < totalChunks; i++ { + fmt.Fprintf(w, "data: chunk-%d\n\n", i) + if flusher != nil { + flusher.Flush() + } + time.Sleep(chunkInterval) + } + })) + defer srv.Close() + + base := &fasthttp.Client{ + ReadTimeout: 1 * time.Second, // would abort the stream without the fix + WriteTimeout: 1 * time.Second, + } + ConfigureDialer(base) + stream := BuildStreamingClient(base) + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(srv.URL) + req.Header.SetMethod(http.MethodGet) + resp.StreamBody = true + + if err := stream.Do(req, resp); err != nil { + t.Fatalf("Do: %v", err) + } + if resp.StatusCode() != http.StatusOK { + t.Fatalf("status: %d", resp.StatusCode()) + } + + scanner := bufio.NewScanner(resp.BodyStream()) + got := 0 + for scanner.Scan() { + if line := scanner.Text(); len(line) >= 5 && line[:5] == "data:" { + got++ + } + } + if err := scanner.Err(); err != nil { + t.Fatalf("scanner: %v", err) + } + if got != totalChunks { + t.Errorf("chunks received: got %d, want %d (stream was likely killed early)", got, totalChunks) + } +} + +// TestBuildStreamingHTTPClient_ZerosTimeout verifies the net/http streaming +// client has Timeout=0 and shares the base's Transport. 
+func TestBuildStreamingHTTPClient_ZerosTimeout(t *testing.T) { + transport := &http.Transport{ResponseHeaderTimeout: 10 * time.Second} + base := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + stream := BuildStreamingHTTPClient(base) + + if stream.Timeout != 0 { + t.Errorf("Timeout: got %v, want 0", stream.Timeout) + } + if stream.Transport != base.Transport { + t.Error("Transport: streaming client should share base's Transport") + } + if base.Timeout != 30*time.Second { + t.Errorf("base Timeout mutated: got %v, want 30s", base.Timeout) + } +} + +// TestBuildStreamingHTTPClient_Nil verifies nil base returns empty client +// (not a panic). +func TestBuildStreamingHTTPClient_Nil(t *testing.T) { + stream := BuildStreamingHTTPClient(nil) + if stream == nil { + t.Fatal("BuildStreamingHTTPClient(nil) returned nil") + } + if stream.Timeout != 0 { + t.Errorf("Timeout: got %v, want 0", stream.Timeout) + } +} + +// TestBuildStreamingHTTPClient_LongStreamSurvives verifies that the streaming +// client can read a response body that takes longer than the base client's +// Timeout — proving Timeout=0 actually lifts the whole-request deadline. 
+func TestBuildStreamingHTTPClient_LongStreamSurvives(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + flusher, _ := w.(http.Flusher) + for i := 0; i < 4; i++ { + fmt.Fprintf(w, "data: chunk-%d\n\n", i) + if flusher != nil { + flusher.Flush() + } + time.Sleep(400 * time.Millisecond) + } + })) + defer srv.Close() + + base := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{Timeout: 5 * time.Second}).DialContext, + ResponseHeaderTimeout: 5 * time.Second, + }, + Timeout: 500 * time.Millisecond, // would abort the stream without the fix + } + stream := BuildStreamingHTTPClient(base) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("NewRequestWithContext: %v", err) + } + resp, err := stream.Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + got := 0 + for scanner.Scan() { + if line := scanner.Text(); len(line) >= 5 && line[:5] == "data:" { + got++ + } + } + if err := scanner.Err(); err != nil { + t.Fatalf("scanner: %v", err) + } + if got != 4 { + t.Errorf("chunks received: got %d, want 4 (stream was likely killed by Timeout)", got) + } +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 748c83210b..26e66ae40b 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -832,7 +832,6 @@ func CloneFastHTTPClientConfig(base *fasthttp.Client) *fasthttp.Client { DialTimeout: base.DialTimeout, Dial: base.Dial, TLSConfig: base.TLSConfig, - RetryIf: base.RetryIf, // nolint:staticcheck RetryIfErr: base.RetryIfErr, ConfigureClient: base.ConfigureClient, Name: base.Name, @@ -855,6 +854,43 @@ func 
CloneFastHTTPClientConfig(base *fasthttp.Client) *fasthttp.Client { } } +// BuildStreamingClient returns a fasthttp.Client suitable for long-lived SSE +// or EventStream responses. It clones base's dialer/proxy/TLS/pool settings, +// then clears Read/Write timeouts and MaxConnDuration so fasthttp does not +// pre-empt a healthy stream. StreamResponseBody is forced on. +// +// Per-chunk idle detection is enforced at the application layer via +// NewIdleTimeoutReader (see GetStreamIdleTimeout / StreamIdleTimeoutInSeconds). +// The initial TCP/TLS dial still honors the base client's ReadTimeout because +// the Dial closure installed by ConfigureDialer reads client.ReadTimeout from +// the base client pointer captured at ConfigureDialer call time — cloning copies +// that closure verbatim, so zeroing the clone's ReadTimeout does not affect dial. +func BuildStreamingClient(base *fasthttp.Client) *fasthttp.Client { + c := CloneFastHTTPClientConfig(base) + c.ReadTimeout = 0 + c.WriteTimeout = 0 + c.MaxConnDuration = 0 + c.StreamResponseBody = true + return c +} + +// BuildStreamingHTTPClient returns an *http.Client for long-lived streaming +// responses over net/http (e.g. Bedrock EventStream). It reuses the base's +// Transport (safe for concurrent use by multiple clients) and sets Timeout=0 +// so Client.Timeout does not cap the entire request lifecycle including body +// reads. The transport's ResponseHeaderTimeout still bounds the initial +// response-headers wait; per-chunk idle is enforced by NewIdleTimeoutReader. +func BuildStreamingHTTPClient(base *http.Client) *http.Client { + if base == nil { + return &http.Client{} + } + return &http.Client{ + Transport: base.Transport, + CheckRedirect: base.CheckRedirect, + Jar: base.Jar, + } +} + // decompressBodyStreamIfGzip checks Content-Encoding for gzip and wraps the stream // with on-the-fly decompression using a pooled gzip.Reader. Clears Content-Encoding // header so downstream consumers don't double-decompress. 
Returns original reader @@ -1666,7 +1702,7 @@ func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse b } // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event. -func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { +func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context)) { firstChunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeCreated, SequenceNumber: 0, @@ -1680,11 +1716,11 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: firstChunk, } - ProcessAndSendResponse(ctx, postHookRunner, bifrostResponse, responseChan) + ProcessAndSendResponse(ctx, postHookRunner, bifrostResponse, responseChan, postHookSpanFinalizer) } // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event -func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { +func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context)) { chunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeInProgress, SequenceNumber: 1, @@ -1698,7 +1734,7 @@ func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunn bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: chunk, } - ProcessAndSendResponse(ctx, postHookRunner, bifrostResponse, responseChan) + ProcessAndSendResponse(ctx, 
postHookRunner, bifrostResponse, responseChan, postHookSpanFinalizer) } // BuildClientStreamChunk constructs a BifrostStreamChunk from post-hook results. @@ -1811,6 +1847,7 @@ func ProcessAndSendResponse( postHookRunner schemas.PostHookRunner, response *schemas.BifrostResponse, responseChan chan *schemas.BifrostStreamChunk, + postHookSpanFinalizer func(context.Context), ) { // Accumulate chunk for tracing (common for all providers) if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil { @@ -1826,7 +1863,7 @@ func ProcessAndSendResponse( // Even if skipping, complete the deferred span if this is the final chunk if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { - completeDeferredSpan(ctx, processedResponse, processedError) + completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } return @@ -1843,7 +1880,7 @@ func ProcessAndSendResponse( // Check if this is the final chunk and complete deferred span with post-processed data if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { - completeDeferredSpan(ctx, processedResponse, processedError) + completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } } @@ -1859,6 +1896,7 @@ func ProcessAndSendBifrostError( bifrostErr *schemas.BifrostError, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) { // Run post hooks first so span reflects post-processed data processedResponse, processedError := postHookRunner(ctx, nil, bifrostErr) @@ -1867,7 +1905,7 @@ func ProcessAndSendBifrostError( // Even if skipping, complete the deferred span if this is the final chunk if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := 
isFinalChunk.(bool); ok && final { - completeDeferredSpan(ctx, processedResponse, processedError) + completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } return @@ -1883,7 +1921,7 @@ func ProcessAndSendBifrostError( // Check if this is the final chunk and complete deferred span with post-processed data if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { - completeDeferredSpan(ctx, processedResponse, processedError) + completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } } @@ -1903,21 +1941,13 @@ func ProcessAndSendBifrostError( // // Panics inside the finalizer are recovered and logged so they never mask an // in-flight panic that triggered the defer. -func EnsureStreamFinalizerCalled(ctx context.Context) { - // Install the recover first so any panic — including one triggered by - // accessing ctx itself — is caught. This matters because this helper is - // called from `defer`, so a panic here would mask the in-flight panic - // that invoked the defer. 
+func EnsureStreamFinalizerCalled(ctx context.Context, finalizer func(context.Context)) { defer func() { if r := recover(); r != nil { getLogger().Debug("recovered panic in deferred stream finalizer: %v", r) } }() - if ctx == nil { - return - } - finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)) - if !ok || finalizer == nil { + if finalizer == nil { return } finalizer(ctx) @@ -2046,6 +2076,7 @@ func HandleStreamCancellation( postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) { // Check if already handled (StreamEndIndicator already set) if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { @@ -2063,7 +2094,7 @@ func HandleStreamCancellation( } // Send through PostHook chain - this updates the log to "error" status - ProcessAndSendBifrostError(ctx, postHookRunner, cancelErr, responseChan, logger) + ProcessAndSendBifrostError(ctx, postHookRunner, cancelErr, responseChan, logger, postHookSpanFinalizer) } // HandleStreamTimeout should be called when a streaming goroutine exits @@ -2079,6 +2110,7 @@ func HandleStreamTimeout( postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) { // Check if already handled (StreamEndIndicator already set) if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { @@ -2096,7 +2128,7 @@ func HandleStreamTimeout( } // Send through PostHook chain - this updates the log to "error" status - ProcessAndSendBifrostError(ctx, postHookRunner, timeoutErr, responseChan, logger) + ProcessAndSendBifrostError(ctx, postHookRunner, timeoutErr, responseChan, logger, postHookSpanFinalizer) } // ProcessAndSendError handles post-hook processing and sends the error to the channel. 
@@ -2109,6 +2141,7 @@ func ProcessAndSendError( err error, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, + postHookSpanFinalizer func(context.Context), ) { // Send scanner error through channel bifrostError := &schemas.BifrostError{ @@ -2610,7 +2643,7 @@ func GetBudgetTokensFromReasoningEffort( // This is called when the final chunk is processed (when StreamEndIndicator is true). // It retrieves the deferred span handle from TraceStore using the trace ID from context, // populates response attributes from accumulated chunks, and ends the span. -func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) { +func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError, postHookSpanFinalizer func(context.Context)) { if ctx == nil { return } @@ -2669,14 +2702,14 @@ func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostRe // Finalize aggregated post-hook spans before ending the LLM span // This creates one span per plugin with average execution time // We need to set the llm.call span ID in context so post-hook spans become its children - if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { + if postHookSpanFinalizer != nil { // Get the deferred span ID (the llm.call span) to set as parent for post-hook spans spanID := tracer.GetDeferredSpanID(traceID) if spanID != "" { finalizerCtx := context.WithValue(ctx, schemas.BifrostContextKeySpanID, spanID) - finalizer(finalizerCtx) + postHookSpanFinalizer(finalizerCtx) } else { - finalizer(ctx) + postHookSpanFinalizer(ctx) } } diff --git a/core/providers/utils/utils_test.go b/core/providers/utils/utils_test.go index e832980f4f..223d341509 100644 --- a/core/providers/utils/utils_test.go +++ b/core/providers/utils/utils_test.go @@ -1198,7 +1198,7 @@ func 
TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChu } responseChan := make(chan *schemas.BifrostStreamChunk, 1) - ProcessAndSendResponse(ctx, passThrough, response, responseChan) + ProcessAndSendResponse(ctx, passThrough, response, responseChan, nil) chunk := <-responseChan if chunk.BifrostChatResponse == nil { @@ -1289,7 +1289,7 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk( responseChan := make(chan *schemas.BifrostStreamChunk, 1) ProcessAndSendResponse(ctx, errorRunner, &schemas.BifrostResponse{ ChatResponse: &schemas.BifrostChatResponse{ID: "chatcmpl-001"}, - }, responseChan) + }, responseChan, nil) chunk := <-responseChan if chunk.BifrostError == nil { @@ -1332,8 +1332,8 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk( // TestShouldSendBackRawRequest verifies that ShouldSendBackRawRequest correctly resolves // whether providers should capture the raw request body. It covers: // - Default (no context flags): returns the provider default -// - BifrostContextKeySendBackRawRequest=true in context: always returns true -// - Logging-only mode: requestWorker sets BifrostContextKeySendBackRawRequest=true, +// - BifrostContextKeyCaptureRawRequest=true in context: always returns true +// - Logging-only mode: requestWorker sets BifrostContextKeyCaptureRawRequest=true, // so the function sees a single flag (no second check needed). func TestShouldSendBackRawRequest(t *testing.T) { tests := []struct { @@ -1363,7 +1363,7 @@ func TestShouldSendBackRawRequest(t *testing.T) { want: true, }, { - // requestWorker sets BifrostContextKeySendBackRawRequest=true in logging-only + // requestWorker sets BifrostContextKeyCaptureRawRequest=true in logging-only // mode so a single flag covers both full send-back and logging-only cases. 
name: "logging-only: context SendBack=true set by requestWorker", contextSendBack: true, @@ -1376,7 +1376,7 @@ func TestShouldSendBackRawRequest(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) if tt.contextSendBack { - ctx.SetValue(schemas.BifrostContextKeySendBackRawRequest, true) + ctx.SetValue(schemas.BifrostContextKeyCaptureRawRequest, true) } got := ShouldSendBackRawRequest(ctx, tt.providerDefault) @@ -1416,7 +1416,7 @@ func TestShouldSendBackRawResponse(t *testing.T) { want: true, }, { - // requestWorker sets BifrostContextKeySendBackRawResponse=true in logging-only + // requestWorker sets BifrostContextKeyCaptureRawResponse=true in logging-only // mode so a single flag covers both full send-back and logging-only cases. name: "logging-only: context SendBack=true set by requestWorker", contextSendBack: true, @@ -1429,7 +1429,7 @@ func TestShouldSendBackRawResponse(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) if tt.contextSendBack { - ctx.SetValue(schemas.BifrostContextKeySendBackRawResponse, true) + ctx.SetValue(schemas.BifrostContextKeyCaptureRawResponse, true) } got := ShouldSendBackRawResponse(ctx, tt.providerDefault) diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index 2fbe83979d..d373f58735 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -24,6 +24,7 @@ func (*VertexRankRequest) GetExtraParams() map[string]interface{} { const ( vertexDefaultRankingConfigID = "default_ranking_config" + vertexDefaultRerankModel = "semantic-ranker-default@latest" vertexMaxRerankRecordsPerQuery = 200 vertexSyntheticRecordPrefix = "idx:" ) diff --git a/core/providers/vertex/rerank.go b/core/providers/vertex/rerank.go index b06430fcac..257a1f8def 100644 --- a/core/providers/vertex/rerank.go +++ b/core/providers/vertex/rerank.go @@ -132,9 +132,11 @@ 
func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vert rankRequest.TopN = &topN } - if trimmedModel := strings.TrimSpace(bifrostReq.Model); trimmedModel != "" { - rankRequest.Model = &trimmedModel + trimmedModel := strings.TrimSpace(bifrostReq.Model) + if trimmedModel == "" { + trimmedModel = vertexDefaultRerankModel } + rankRequest.Model = &trimmedModel ignoreRecordDetailsInResponse := options.IgnoreRecordDetailsInResponse rankRequest.IgnoreRecordDetailsInResponse = &ignoreRecordDetailsInResponse diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 9a4792eb6a..4ffdda267d 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -74,7 +74,8 @@ func removeVertexClient(authCredentials string) { // VertexProvider implements the Provider interface for Google's Vertex AI API. type VertexProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -98,9 +99,11 @@ func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) return &VertexProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: 
config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -356,7 +359,7 @@ func (provider *VertexProvider) TextCompletion(ctx *schemas.BifrostContext, key // TextCompletionStream performs a streaming text completion request to Vertex's API. // It formats the request, sends it to Vertex, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } @@ -675,7 +678,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // ChatCompletionStream performs a streaming chat completion request to the Vertex API. // It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models). // Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails. 
-func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { providerName := provider.GetProviderKey() projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { @@ -780,7 +783,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Use shared Anthropic streaming logic return anthropic.HandleAnthropicChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, completeURL, jsonData, headers, @@ -792,6 +795,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { // Use Gemini-style streaming for Gemini models @@ -859,7 +863,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Use shared streaming logic from Gemini return gemini.HandleGeminiChatCompletionStream( ctx, - provider.client, + provider.streamingClient, completeURL, jsonData, headers, @@ -871,6 +875,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } else { var authHeader map[string]string @@ -917,7 +922,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Use shared OpenAI streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, completeURL, request, authHeader, @@ -932,6 +937,7 @@ 
func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, provider.logger, + postHookSpanFinalizer, ) } } @@ -1207,7 +1213,7 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem } // ResponsesStream performs a streaming responses request to the Vertex API. -func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if schemas.IsAnthropicModel(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { @@ -1252,7 +1258,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Use shared streaming logic from Anthropic return anthropic.HandleAnthropicResponsesStream( ctx, - provider.client, + provider.streamingClient, url, jsonBody, headers, @@ -1264,6 +1270,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { region := key.VertexKeyConfig.Region.GetValue() @@ -1340,7 +1347,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Use shared streaming logic from Gemini return gemini.HandleGeminiResponsesStream( ctx, - provider.client, + provider.streamingClient, completeURL, jsonData, headers, @@ -1352,12 +1359,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos postHookRunner, nil, provider.logger, + postHookSpanFinalizer, ) } else { 
ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -1660,7 +1669,7 @@ func (provider *VertexProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key } // SpeechStream is not supported by the Vertex provider. -func (provider *VertexProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -1670,7 +1679,7 @@ func (provider *VertexProvider) Transcription(ctx *schemas.BifrostContext, key s } // TranscriptionStream is not supported by the Vertex provider. -func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1900,7 +1909,7 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } // ImageGenerationStream is not supported by the Vertex provider. 
-func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -2119,7 +2128,7 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } // ImageEditStream is not supported by the Vertex provider. -func (provider *VertexProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VertexProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -2943,6 +2952,7 @@ func (provider *VertexProvider) Passthrough( func (provider *VertexProvider) PassthroughStream( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, + postHookSpanFinalizer func(context.Context), key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -3046,7 +3056,7 @@ func (provider *VertexProvider) PassthroughStream( fasthttpReq.SetBody(req.Body) } - activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, 
resp) + activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp) if err := activeClient.Do(fasthttpReq, resp); err != nil { providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -3099,12 +3109,12 @@ func (provider *VertexProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer) } close(ch) }() @@ -3142,8 +3152,8 @@ func (provider *VertexProvider) PassthroughStream( }, } postHookRunner(ctx, finalResp, nil) - if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { - finalizer(ctx) + if postHookSpanFinalizer != nil { + postHookSpanFinalizer(ctx) } return } @@ -3153,7 +3163,7 @@ func (provider *VertexProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer) return } } diff --git a/core/providers/vllm/vllm.go b/core/providers/vllm/vllm.go index 548c6a0dc7..dfab966a26 100644 --- a/core/providers/vllm/vllm.go +++ b/core/providers/vllm/vllm.go @@ -20,7 +20,8 @@ import ( // VLLMProvider implements the 
Provider interface for vLLM's OpenAI-compatible API. type VLLMProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -44,12 +45,14 @@ func NewVLLMProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VL client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") // BaseURL is optional when keys have vllm_key_config with per-key URLs return &VLLMProvider{ logger: logger, client: client, + streamingClient: streamingClient, networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -137,7 +140,7 @@ func (provider *VLLMProvider) TextCompletion(ctx *schemas.BifrostContext, key sc } // TextCompletionStream performs a streaming text completion request to vLLM's API. 
-func (provider *VLLMProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { baseURL, bifrostErr := provider.baseURLOrError(key) if bifrostErr != nil { return nil, bifrostErr @@ -148,7 +151,7 @@ func (provider *VLLMProvider) TextCompletionStream(ctx *schemas.BifrostContext, } return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, authHeader, @@ -161,6 +164,7 @@ func (provider *VLLMProvider) TextCompletionStream(ctx *schemas.BifrostContext, HandleVLLMResponse, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -187,7 +191,7 @@ func (provider *VLLMProvider) ChatCompletion(ctx *schemas.BifrostContext, key sc } // ChatCompletionStream performs a streaming chat completion request to vLLM's API. 
-func (provider *VLLMProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { baseURL, bifrostErr := provider.baseURLOrError(key) if bifrostErr != nil { return nil, bifrostErr @@ -198,7 +202,7 @@ func (provider *VLLMProvider) ChatCompletionStream(ctx *schemas.BifrostContext, } return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, authHeader, @@ -213,6 +217,7 @@ func (provider *VLLMProvider) ChatCompletionStream(ctx *schemas.BifrostContext, nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -248,11 +253,12 @@ func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas } // ResponsesStream performs a streaming responses request to vLLM's API (via chat completion stream). 
-func (provider *VLLMProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, + postHookSpanFinalizer, key, request.ToChatRequest(), ) @@ -397,7 +403,7 @@ func (provider *VLLMProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, } // SpeechStream is not supported by the vLLM provider. -func (provider *VLLMProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -422,7 +428,7 @@ func (provider *VLLMProvider) Transcription(ctx *schemas.BifrostContext, key sch } // TranscriptionStream performs a streaming transcription request to vLLM's API. 
-func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { baseURL, bifrostErr := provider.baseURLOrError(key) if bifrostErr != nil { return nil, bifrostErr @@ -475,7 +481,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p req.SetBody(body.Bytes()) // Make the request - err := provider.client.Do(req, resp) + err := provider.streamingClient.Do(req, resp) if err != nil { defer providerUtils.ReleaseStreamingResponse(resp) if errors.Is(err, context.Canceled) { @@ -517,12 +523,12 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Start streaming in a goroutine go func() { - defer providerUtils.EnsureStreamFinalizerCalled(ctx) + defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer) } close(responseChan) }() @@ -562,7 +568,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - 
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer) } break } @@ -580,7 +586,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p _, _, bifrostErr = HandleVLLMResponse(dataBytes, &response, nil, false, false) if bifrostErr != nil { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger, postHookSpanFinalizer) return } @@ -609,11 +615,11 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p response.ExtraFields.Latency = time.Since(startTime).Milliseconds() response.Text = fullTranscriptionText.String() ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan, postHookSpanFinalizer) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan, postHookSpanFinalizer) } }() @@ -627,7 +633,7 @@ func 
(provider *VLLMProvider) ImageGeneration(ctx *schemas.BifrostContext, key s } // ImageGenerationStream is not supported by the vLLM provider. -func (provider *VLLMProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -637,7 +643,7 @@ func (provider *VLLMProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas } // ImageEditStream is not supported by the vLLM provider. -func (provider *VLLMProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *VLLMProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -786,6 +792,6 @@ func (provider *VLLMProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.K return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *VLLMProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, 
*schemas.BifrostError) { +func (provider *VLLMProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index 118b8589bf..e787f307fd 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -3,6 +3,7 @@ package xai import ( + "context" "strings" "time" @@ -15,7 +16,8 @@ import ( // xAIProvider implements the Provider interface for xAI's API. type XAIProvider struct { logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response) + streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader) networkConfig schemas.NetworkConfig // Network configuration including extra headers sendBackRawRequest bool // Whether to include raw request in BifrostResponse sendBackRawResponse bool // Whether to include raw response in BifrostResponse @@ -42,6 +44,7 @@ func NewXAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*XAI client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) client = providerUtils.ConfigureDialer(client) client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + streamingClient := providerUtils.BuildStreamingClient(client) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") if config.NetworkConfig.BaseURL == "" { @@ -51,6 +54,7 @@ func NewXAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*XAI return &XAIProvider{ logger: logger, client: client, + streamingClient: streamingClient, 
networkConfig: config.NetworkConfig, sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, @@ -101,10 +105,10 @@ func (provider *XAIProvider) TextCompletion(ctx *schemas.BifrostContext, key sch // TextCompletionStream performs a streaming text completion request to xAI's API. // It formats the request, sends it to xAI, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. -func (provider *XAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/completions", request, nil, @@ -117,6 +121,7 @@ func (provider *XAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, p nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -142,7 +147,7 @@ func (provider *XAIProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses xAI's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. 
-func (provider *XAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -150,7 +155,7 @@ func (provider *XAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, p // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+"/v1/chat/completions", request, authHeader, @@ -165,6 +170,7 @@ func (provider *XAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, p nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -187,14 +193,14 @@ func (provider *XAIProvider) Responses(ctx *schemas.BifrostContext, key schemas. } // ResponsesStream performs a streaming responses request to the xAI API. 
-func (provider *XAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} } return openai.HandleOpenAIResponsesStreaming( ctx, - provider.client, + provider.streamingClient, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"), request, authHeader, @@ -208,6 +214,7 @@ func (provider *XAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo nil, nil, provider.logger, + postHookSpanFinalizer, ) } @@ -232,7 +239,7 @@ func (provider *XAIProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, r } // SpeechStream is not supported by the xAI provider. -func (provider *XAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } @@ -242,7 +249,7 @@ func (provider *XAIProvider) Transcription(ctx *schemas.BifrostContext, key sche } // TranscriptionStream is not supported by the xAI provider. 
-func (provider *XAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -263,7 +270,7 @@ func (provider *XAIProvider) ImageGeneration(ctx *schemas.BifrostContext, key sc } // ImageGenerationStream is not supported by the xAI provider. -func (provider *XAIProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) } @@ -273,7 +280,7 @@ func (provider *XAIProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas. } // ImageEditStream is not supported by the xAI provider. 
-func (provider *XAIProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) } @@ -421,6 +428,6 @@ func (provider *XAIProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Ke return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) } -func (provider *XAIProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { +func (provider *XAIProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 04cca3915a..b932dc1aee 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -215,7 +215,6 @@ const ( BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" 
// func([]PluginLogEntry) (callback to complete trace after streaming, receives transport plugin logs - set by tracing middleware) - BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) BifrostContextKeyMCPUserID BifrostContextKey = "bifrost-mcp-user-id" // string (per-user OAuth user identifier from X-Bf-User-Id header) diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 5d28002f9f..29aa6c3302 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -2,6 +2,7 @@ package schemas import ( + "context" "encoding/json" "maps" "time" @@ -498,16 +499,19 @@ type Provider interface { ListModels(ctx *BifrostContext, keys []Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError) // TextCompletion performs a text completion request TextCompletion(ctx *BifrostContext, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError) - // TextCompletionStream performs a text completion stream request - TextCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostTextCompletionRequest) (chan *BifrostStreamChunk, *BifrostError) + // TextCompletionStream performs a text completion stream request. + // postHookSpanFinalizer is invoked by the provider's stream goroutine on stream completion + // (or on its panic-recovery defer) to finalize aggregated post-hook spans and release the + // per-attempt plugin pipeline. Pass nil if the caller does not need finalization. 
+ TextCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTextCompletionRequest) (chan *BifrostStreamChunk, *BifrostError) // ChatCompletion performs a chat completion request ChatCompletion(ctx *BifrostContext, key Key, request *BifrostChatRequest) (*BifrostChatResponse, *BifrostError) // ChatCompletionStream performs a chat completion stream request - ChatCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostChatRequest) (chan *BifrostStreamChunk, *BifrostError) + ChatCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostChatRequest) (chan *BifrostStreamChunk, *BifrostError) // Responses performs a completion request using the Responses API (uses chat completion request internally for non-openai providers) Responses(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostResponsesResponse, *BifrostError) // ResponsesStream performs a completion request using the Responses API stream (uses chat completion stream request internally for non-openai providers) - ResponsesStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostResponsesRequest) (chan *BifrostStreamChunk, *BifrostError) + ResponsesStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostResponsesRequest) (chan *BifrostStreamChunk, *BifrostError) // CountTokens performs a count tokens request CountTokens(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError) // Embedding performs an embedding request @@ -519,21 +523,21 @@ type Provider interface { // Speech performs a text to speech request Speech(ctx *BifrostContext, key Key, request *BifrostSpeechRequest) (*BifrostSpeechResponse, *BifrostError) // SpeechStream performs a text 
to speech stream request - SpeechStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostSpeechRequest) (chan *BifrostStreamChunk, *BifrostError) + SpeechStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostSpeechRequest) (chan *BifrostStreamChunk, *BifrostError) // Transcription performs a transcription request Transcription(ctx *BifrostContext, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError) // TranscriptionStream performs a transcription stream request - TranscriptionStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostTranscriptionRequest) (chan *BifrostStreamChunk, *BifrostError) + TranscriptionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTranscriptionRequest) (chan *BifrostStreamChunk, *BifrostError) // ImageGeneration performs an image generation request ImageGeneration(ctx *BifrostContext, key Key, request *BifrostImageGenerationRequest) ( *BifrostImageGenerationResponse, *BifrostError) // ImageGenerationStream performs an image generation stream request - ImageGenerationStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, + ImageGenerationStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostImageGenerationRequest) (chan *BifrostStreamChunk, *BifrostError) // ImageEdit performs an image edit request ImageEdit(ctx *BifrostContext, key Key, request *BifrostImageEditRequest) (*BifrostImageGenerationResponse, *BifrostError) // ImageEditStream performs an image edit stream request - ImageEditStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, + ImageEditStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostImageEditRequest) 
(chan *BifrostStreamChunk, *BifrostError) // ImageVariation performs an image variation request ImageVariation(ctx *BifrostContext, key Key, request *BifrostImageVariationRequest) (*BifrostImageGenerationResponse, *BifrostError) @@ -592,7 +596,7 @@ type Provider interface { // Passthrough executes a non-streaming passthrough; body is fully buffered. Passthrough(ctx *BifrostContext, key Key, req *BifrostPassthroughRequest) (*BifrostPassthroughResponse, *BifrostError) // PassthroughStream executes a streaming passthrough, forwarding raw response bytes as BifrostStreamChunks. - PassthroughStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, req *BifrostPassthroughRequest) (chan *BifrostStreamChunk, *BifrostError) + PassthroughStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, req *BifrostPassthroughRequest) (chan *BifrostStreamChunk, *BifrostError) } // WebSocketCapableProvider is an optional interface that providers can implement diff --git a/framework/configstore/dlock_test.go b/framework/configstore/dlock_test.go index 019abc5624..16a4e892a9 100644 --- a/framework/configstore/dlock_test.go +++ b/framework/configstore/dlock_test.go @@ -90,10 +90,13 @@ func setupLockTestStore(t *testing.T) *RDBConfigStore { err = db.AutoMigrate(&tables.TableDistributedLock{}) require.NoError(t, err, "Failed to migrate test database") - return &RDBConfigStore{ - db: db, - logger: newMockLogger(), + s := &RDBConfigStore{logger: newMockLogger()} + s.db.Store(db) + s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, s.DB()) } + s.refreshPoolFn = func(ctx context.Context) error { return nil } + return s } // ============================================================================= @@ -241,7 +244,7 @@ func TestUpdateLockExpiry_ExpiredLock(t *testing.T) { ExpiresAt: time.Now().UTC().Add(-1 * time.Second), } // Directly insert the expired lock - err 
:= store.db.Create(lock).Error + err := store.DB().Create(lock).Error require.NoError(t, err) // Try to extend expired lock @@ -327,11 +330,11 @@ func TestCleanupExpiredLocks_Success(t *testing.T) { } for _, l := range expiredLocks { - err := store.db.Create(&l).Error + err := store.DB().Create(&l).Error require.NoError(t, err) } for _, l := range validLocks { - err := store.db.Create(&l).Error + err := store.DB().Create(&l).Error require.NoError(t, err) } @@ -383,7 +386,7 @@ func TestCleanupExpiredLockByKey_Success(t *testing.T) { HolderID: "holder-1", ExpiresAt: time.Now().UTC().Add(-1 * time.Minute), } - err := store.db.Create(&lock).Error + err := store.DB().Create(&lock).Error require.NoError(t, err) // Cleanup specific expired lock @@ -505,7 +508,7 @@ func TestDistributedLockManager_CleanupExpiredLocks(t *testing.T) { HolderID: "holder-1", ExpiresAt: time.Now().UTC().Add(-1 * time.Minute), } - err := store.db.Create(&lock).Error + err := store.DB().Create(&lock).Error require.NoError(t, err) count, err := manager.CleanupExpiredLocks(ctx) @@ -565,7 +568,7 @@ func TestDistributedLock_TryLock_CleansUpExpired(t *testing.T) { HolderID: "old-holder", ExpiresAt: time.Now().UTC().Add(-1 * time.Minute), } - err := store.db.Create(&expiredLock).Error + err := store.DB().Create(&expiredLock).Error require.NoError(t, err) // New lock should be able to acquire after cleanup @@ -772,7 +775,7 @@ func TestDistributedLock_Extend_StolenLock(t *testing.T) { require.NoError(t, err) // Simulate lock being stolen by another process - err = store.db.Model(&tables.TableDistributedLock{}). + err = store.DB().Model(&tables.TableDistributedLock{}). Where("lock_key = ?", "test-lock"). Update("holder_id", "another-holder").Error require.NoError(t, err) @@ -844,7 +847,7 @@ func TestDistributedLock_IsHeld_StolenByAnotherHolder(t *testing.T) { require.NoError(t, err) // Simulate lock being stolen by another process - err = store.db.Model(&tables.TableDistributedLock{}). 
+ err = store.DB().Model(&tables.TableDistributedLock{}). Where("lock_key = ?", "test-lock"). Update("holder_id", "another-holder").Error require.NoError(t, err) @@ -866,7 +869,7 @@ func TestDistributedLock_IsHeld_DeletedFromDB(t *testing.T) { require.NoError(t, err) // Delete lock directly from database - err = store.db.Where("lock_key = ?", "test-lock").Delete(&tables.TableDistributedLock{}).Error + err = store.DB().Where("lock_key = ?", "test-lock").Delete(&tables.TableDistributedLock{}).Error require.NoError(t, err) held, err := lock.IsHeld(ctx) diff --git a/framework/configstore/encryption.go b/framework/configstore/encryption.go index b8818cb9ba..b2de668abe 100644 --- a/framework/configstore/encryption.go +++ b/framework/configstore/encryption.go @@ -101,7 +101,7 @@ func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) var count int for { var keys []tables.TableKey - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&keys).Error; err != nil { @@ -110,7 +110,7 @@ func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) if len(keys) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range keys { if err := tx.Save(&keys[i]).Error; err != nil { return err @@ -131,7 +131,7 @@ func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, var count int for { var vks []tables.TableVirtualKey - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText). Limit(encryptionBatchSize). 
Find(&vks).Error; err != nil { @@ -140,7 +140,7 @@ func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, if len(vks) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range vks { if err := tx.Save(&vks[i]).Error; err != nil { return err @@ -161,7 +161,7 @@ func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, err var count int for { var sessions []tables.SessionsTable - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&sessions).Error; err != nil { @@ -170,7 +170,7 @@ func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, err if len(sessions) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range sessions { if err := tx.Save(&sessions[i]).Error; err != nil { return err @@ -191,7 +191,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, var count int for { var tokens []tables.TableOauthToken - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText). Limit(encryptionBatchSize). 
Find(&tokens).Error; err != nil { @@ -200,7 +200,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, if len(tokens) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range tokens { if err := tx.Save(&tokens[i]).Error; err != nil { return err @@ -221,7 +221,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, var count int for { var configs []tables.TableOauthConfig - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&configs).Error; err != nil { @@ -230,7 +230,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, if len(configs) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range configs { if err := tx.Save(&configs[i]).Error; err != nil { return err @@ -251,7 +251,7 @@ func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, e var count int for { var clients []tables.TableMCPClient - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText). Limit(encryptionBatchSize). 
Find(&clients).Error; err != nil { @@ -260,7 +260,7 @@ func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, e if len(clients) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range clients { if err := tx.Save(&clients[i]).Error; err != nil { return err @@ -282,7 +282,7 @@ func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (i var count int for { var providers []tables.TableProvider - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&providers).Error; err != nil { @@ -291,7 +291,7 @@ func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (i if len(providers) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range providers { if err := tx.Save(&providers[i]).Error; err != nil { return err @@ -313,7 +313,7 @@ func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) var count int for { var configs []tables.TableVectorStoreConfig - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText). Limit(encryptionBatchSize). 
Find(&configs).Error; err != nil { @@ -322,7 +322,7 @@ func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) if len(configs) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range configs { if err := tx.Save(&configs[i]).Error; err != nil { return err @@ -344,7 +344,7 @@ func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, erro var count int for { var plugins []tables.TablePlugin - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&plugins).Error; err != nil { @@ -353,7 +353,7 @@ func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, erro if len(plugins) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range plugins { if err := tx.Save(&plugins[i]).Error; err != nil { return err diff --git a/framework/configstore/encryption_test.go b/framework/configstore/encryption_test.go index 9ac36baede..e7ad6272b1 100644 --- a/framework/configstore/encryption_test.go +++ b/framework/configstore/encryption_test.go @@ -54,10 +54,12 @@ func setupEncryptionTestStore(t *testing.T) (*RDBConfigStore, *gorm.DB) { ) require.NoError(t, err) - store := &RDBConfigStore{ - db: db, - logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + store := &RDBConfigStore{logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo)} + store.db.Store(db) + store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, store.DB()) } + store.refreshPoolFn = func(ctx context.Context) error { return nil } return store, db } diff --git 
a/framework/configstore/migrations.go b/framework/configstore/migrations.go index a0352382f8..05f33a3bec 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -72,6 +72,22 @@ func (l *migrationLock) release(ctx context.Context) { l.conn.Close() } +// RunSingleMigration applies a single gormigrate migration on the given +// *gorm.DB. Mirrors (*RDBConfigStore).RunMigration but takes the *gorm.DB +// directly, so downstream consumers (bifrost-enterprise, plugins) can run +// their migrations inside a MigrateOnFreshConnection callback without having +// to reach the throwaway pool through the ConfigStore abstraction. +func RunSingleMigration(ctx context.Context, db *gorm.DB, migration *migrator.Migration) error { + if db == nil { + return fmt.Errorf("db cannot be nil") + } + if migration == nil { + return fmt.Errorf("migration cannot be nil") + } + m := migrator.New(db.WithContext(ctx), migrator.DefaultOptions, []*migrator.Migration{migration}) + return m.Migrate() +} + // Migrate performs the necessary database migrations. func triggerMigrations(ctx context.Context, db *gorm.DB) error { // Acquire advisory lock to serialize migrations across cluster nodes. 
@@ -398,6 +414,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationNormalizeOtelTraceType(ctx, db); err != nil { return err } + if err := migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx, db); err != nil { return err } + return nil } @@ -6543,3 +6562,72 @@ func migrationNormalizeOtelTraceType(ctx context.Context, db *gorm.DB) error { } return nil } + +// migrateCalendarAlignedToBudgetsAndRateLimitsTable adds the calendar_aligned column to governance_budgets and governance_rate_limits, then backfills it from each virtual key's legacy calendar_aligned flag. +func migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "migrate_calendar_aligned", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + // Add the calendar_aligned column to budgets before any backfill. + if !mig.HasColumn(&tables.TableBudget{}, "calendar_aligned") { + if err := mig.AddColumn(&tables.TableBudget{}, "calendar_aligned"); err != nil { + return fmt.Errorf("failed to add calendar_aligned column to budgets: %w", err) + } + } + // Add the calendar_aligned column to rate_limits as well. + if !mig.HasColumn(&tables.TableRateLimit{}, "calendar_aligned") { + if err := mig.AddColumn(&tables.TableRateLimit{}, "calendar_aligned"); err != nil { + return fmt.Errorf("failed to add calendar_aligned column to rate_limits: %w", err) + } + } + // Prefill calendar_aligned for existing budgets and rate_limits attached to virtual keys. + // GORM v2: Preload must precede the Find finisher, otherwise it's a no-op on the executed query. + var virtualKeys []tables.TableVirtualKey + if err := tx.Preload("Budgets").Find(&virtualKeys).Error; err != nil { + return fmt.Errorf("failed to load virtual keys: %w", err) + } + for i := range virtualKeys { + // Preserve the legacy per-VK semantic: only copy calendar_aligned=true to + // the VK's budgets and rate_limit when the source VK itself was aligned. + // Hardcoding true would change reset behavior for tenants whose VKs were + // never calendar-aligned.
+ if !virtualKeys[i].CalendarAligned { + continue + } + // Ratelimit updates. A stale rate_limit_id is skipped — the FK is intentionally + // not DB-enforced for TableVirtualKey — but the VK's budgets are still migrated. + if virtualKeys[i].RateLimitID != nil { + var rateLimit tables.TableRateLimit + err := tx.First(&rateLimit, virtualKeys[i].RateLimitID).Error + switch { + case err == gorm.ErrRecordNotFound: + // Skip only the rate-limit update; fall through to the budget loop. + case err != nil: + return fmt.Errorf("failed to load rate limit for virtual key %s: %w", virtualKeys[i].ID, err) + default: + rateLimit.CalendarAligned = true + if err := tx.Save(&rateLimit).Error; err != nil { + return fmt.Errorf("failed to save rate limit for virtual key %s: %w", virtualKeys[i].ID, err) + } + } + } + // Budgets update + for j := range virtualKeys[i].Budgets { + virtualKeys[i].Budgets[j].CalendarAligned = true + if err := tx.Save(&virtualKeys[i].Budgets[j]).Error; err != nil { + return fmt.Errorf("failed to save budget for virtual key %s: %w", virtualKeys[i].ID, err) + } + } + } + log.Printf("[Migration] Prefilled calendar_aligned field for existing budgets and rate limits") + return nil + }, + Rollback: func(tx *gorm.DB) error { return nil }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running migrate_calendar_aligned migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go index b03afaa7ff..329f1be98c 100644 --- a/framework/configstore/migrations_test.go +++ b/framework/configstore/migrations_test.go @@ -1122,10 +1122,12 @@ func setupFullMigrationDB(t *testing.T) (*RDBConfigStore, *gorm.DB) { err = triggerMigrations(ctx, db) require.NoError(t, err, "triggerMigrations should succeed on a fresh DB") - store := &RDBConfigStore{ - db: db, - logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + store := &RDBConfigStore{logger: 
bifrost.NewDefaultLogger(schemas.LogLevelInfo)} + store.db.Store(db) + store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, store.DB()) } + store.refreshPoolFn = func(ctx context.Context) error { return nil } return store, db } @@ -2002,3 +2004,250 @@ func TestMigrationReplaceEnableLiteLLMWithCompatColumns(t *testing.T) { assert.False(t, rows[1].CompatShouldConvertParams, "compat_should_convert_params should default to false") } +// setupCalendarAlignedPreMigrationDB creates a SQLite DB with governance_virtual_keys, +// governance_budgets, and governance_rate_limits tables, then drops the calendar_aligned +// column from budgets and rate_limits to simulate the pre-migration schema state. +func setupCalendarAlignedPreMigrationDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err, "Failed to create test database") + + require.NoError(t, db.AutoMigrate( + &tables.TableBudget{}, + &tables.TableRateLimit{}, + &tables.TableVirtualKey{}, + ), "Failed to auto-migrate governance tables") + + // Simulate pre-migration state: drop calendar_aligned from budgets and rate_limits. + mig := db.Migrator() + if mig.HasColumn(&tables.TableBudget{}, "calendar_aligned") { + require.NoError(t, mig.DropColumn(&tables.TableBudget{}, "calendar_aligned")) + } + if mig.HasColumn(&tables.TableRateLimit{}, "calendar_aligned") { + require.NoError(t, mig.DropColumn(&tables.TableRateLimit{}, "calendar_aligned")) + } + + require.NoError(t, db.Exec(`CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)`).Error) + return db +} + +// insertBudgetRaw inserts a budget via raw SQL so the pre-migration table (without +// calendar_aligned) accepts the row without tripping BeforeSave hooks. 
+func insertBudgetRaw(t *testing.T, db *gorm.DB, id string, vkID *string) { + t.Helper() + now := time.Now() + err := db.Exec(` + INSERT INTO governance_budgets + (id, max_limit, reset_duration, last_reset, current_usage, virtual_key_id, config_hash, created_at, updated_at) + VALUES (?, ?, ?, ?, 0, ?, '', ?, ?) + `, id, 100.0, "1d", now, vkID, now, now).Error + require.NoError(t, err, "Failed to insert budget %s", id) +} + +// insertRateLimitRaw inserts a rate_limit via raw SQL mirroring insertBudgetRaw. +func insertRateLimitRaw(t *testing.T, db *gorm.DB, id string) { + t.Helper() + now := time.Now() + tokenMax := int64(1000) + tokenDur := "1h" + err := db.Exec(` + INSERT INTO governance_rate_limits + (id, token_max_limit, token_reset_duration, token_current_usage, token_last_reset, + request_current_usage, request_last_reset, config_hash, created_at, updated_at) + VALUES (?, ?, ?, 0, ?, 0, ?, '', ?, ?) + `, id, tokenMax, tokenDur, now, now, now, now).Error + require.NoError(t, err, "Failed to insert rate limit %s", id) +} + +// insertVKRaw inserts a virtual key via raw SQL so we can set rate_limit_id without +// running the VK BeforeSave hook (which handles encryption). +func insertVKRaw(t *testing.T, db *gorm.DB, id, name, value string, rateLimitID *string, calendarAligned bool) { + t.Helper() + now := time.Now() + err := db.Exec(` + INSERT INTO governance_virtual_keys + (id, name, value, is_active, rate_limit_id, calendar_aligned, encryption_status, value_hash, config_hash, created_at, updated_at) + VALUES (?, ?, ?, 1, ?, ?, 'plain_text', ?, '', ?, ?) + `, id, name, value, rateLimitID, calendarAligned, id+"-hash", now, now).Error + require.NoError(t, err, "Failed to insert virtual key %s", id) +} + +// TestMigrationCalendarAligned_AddColumnsAndBackfill exercises the full migration: +// column addition on governance_budgets and governance_rate_limits, plus backfill +// of calendar_aligned=true for rows attached to any virtual key. 
Rows NOT attached +// to a VK must remain at the default (false). +func TestMigrationCalendarAligned_AddColumnsAndBackfill(t *testing.T) { + db := setupCalendarAlignedPreMigrationDB(t) + ctx := context.Background() + + // Pre-condition: columns should be absent before running the migration. + mig := db.Migrator() + assert.False(t, mig.HasColumn(&tables.TableBudget{}, "calendar_aligned"), + "calendar_aligned should NOT exist on budgets before migration") + assert.False(t, mig.HasColumn(&tables.TableRateLimit{}, "calendar_aligned"), + "calendar_aligned should NOT exist on rate_limits before migration") + + // Seed: VK-aligned carries the legacy calendar_aligned=true flag; its rate limit + // and both attached budgets SHOULD be backfilled. + insertRateLimitRaw(t, db, "rl-aligned") + vkAlignedID := "vk-aligned" + rlAlignedID := "rl-aligned" + insertVKRaw(t, db, vkAlignedID, "vk-aligned", "vk-aligned-value", &rlAlignedID, true) + insertBudgetRaw(t, db, "budget-aligned-1", &vkAlignedID) + insertBudgetRaw(t, db, "budget-aligned-2", &vkAlignedID) + + // Seed: VK-unaligned carries calendar_aligned=false; its rate limit and + // budgets MUST NOT be backfilled (preserves legacy per-VK semantics). + insertRateLimitRaw(t, db, "rl-unaligned") + vkUnalignedID := "vk-unaligned" + rlUnalignedID := "rl-unaligned" + insertVKRaw(t, db, vkUnalignedID, "vk-unaligned", "vk-unaligned-value", &rlUnalignedID, false) + insertBudgetRaw(t, db, "budget-unaligned", &vkUnalignedID) + + // Seed: VK-aligned-no-rl has calendar_aligned=true but no rate limit — its + // single budget should still be backfilled. + vkAlignedNoRLID := "vk-aligned-no-rl" + insertVKRaw(t, db, vkAlignedNoRLID, "vk-aligned-no-rl", "vk-aligned-no-rl-value", nil, true) + insertBudgetRaw(t, db, "budget-aligned-no-rl", &vkAlignedNoRLID) + + // Seed: orphaned rows (no VK owner) — these must NOT be backfilled. + insertRateLimitRaw(t, db, "rl-orphan") + insertBudgetRaw(t, db, "budget-orphan", nil) + + // Run the migration. 
+ require.NoError(t, migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx, db), + "migration should succeed") + + // Post-condition: columns exist. + assert.True(t, mig.HasColumn(&tables.TableBudget{}, "calendar_aligned"), + "calendar_aligned should exist on budgets after migration") + assert.True(t, mig.HasColumn(&tables.TableRateLimit{}, "calendar_aligned"), + "calendar_aligned should exist on rate_limits after migration") + + // Verify per-row state. + type budgetRow struct { + ID string + CalendarAligned bool + } + type rlRow struct { + ID string + CalendarAligned bool + } + + assertBudget := func(id string, expected bool) { + t.Helper() + var row budgetRow + err := db.Table("governance_budgets"). + Select("id, calendar_aligned"). + Where("id = ?", id).Scan(&row).Error + require.NoError(t, err) + assert.Equal(t, expected, row.CalendarAligned, + "budget %s calendar_aligned mismatch", id) + } + assertRateLimit := func(id string, expected bool) { + t.Helper() + var row rlRow + err := db.Table("governance_rate_limits"). + Select("id, calendar_aligned"). + Where("id = ?", id).Scan(&row).Error + require.NoError(t, err) + assert.Equal(t, expected, row.CalendarAligned, + "rate_limit %s calendar_aligned mismatch", id) + } + + // Rows attached to VKs with calendar_aligned=true should be backfilled. + assertBudget("budget-aligned-1", true) + assertBudget("budget-aligned-2", true) + assertBudget("budget-aligned-no-rl", true) + assertRateLimit("rl-aligned", true) + + // Rows attached to VKs with calendar_aligned=false must remain false — + // the migration must preserve the legacy per-VK semantic. + assertBudget("budget-unaligned", false) + assertRateLimit("rl-unaligned", false) + + // Orphaned rows should retain the default (false). 
+ assertBudget("budget-orphan", false) + assertRateLimit("rl-orphan", false) +} + +// TestMigrationCalendarAligned_StaleRateLimitID verifies that a VK pointing at a +// deleted rate_limit row is skipped rather than aborting the migration (the FK +// is intentionally not DB-enforced on TableVirtualKey). +func TestMigrationCalendarAligned_StaleRateLimitID(t *testing.T) { + db := setupCalendarAlignedPreMigrationDB(t) + ctx := context.Background() + + // VK-stale carries calendar_aligned=true (so the migration touches it) and + // references a rate_limit_id that doesn't exist in the table. + stale := "rl-does-not-exist" + insertVKRaw(t, db, "vk-stale", "vk-stale", "vk-stale-value", &stale, true) + insertBudgetRaw(t, db, "budget-stale-vk", strPtr("vk-stale")) + + // VK-ok has calendar_aligned=true and a valid rate_limit so we can verify + // the loop keeps processing after the stale skip. + insertRateLimitRaw(t, db, "rl-ok") + vkOKID := "vk-ok" + rlOK := "rl-ok" + insertVKRaw(t, db, vkOKID, "vk-ok", "vk-ok-value", &rlOK, true) + insertBudgetRaw(t, db, "budget-ok", &vkOKID) + + // Migration must NOT return an error despite the stale reference. + require.NoError(t, migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx, db), + "stale rate_limit_id should be skipped, not abort the migration") + + // The VK with the stale ref still has its budgets backfilled. + var staleVKBudget bool + require.NoError(t, db.Table("governance_budgets"). + Select("calendar_aligned"). + Where("id = ?", "budget-stale-vk").Scan(&staleVKBudget).Error) + assert.True(t, staleVKBudget, "budget attached to the stale-ref VK should still be migrated") + + // Non-stale VK should have been fully processed. + var okRL bool + require.NoError(t, db.Table("governance_rate_limits"). + Select("calendar_aligned"). 
+ Where("id = ?", "rl-ok").Scan(&okRL).Error) + assert.True(t, okRL, "valid rate_limit should be migrated even when a prior VK had a stale ref") +} + +// TestMigrationCalendarAligned_Idempotent ensures the migration can run twice +// without error — the migrator library must record the migration ID so the +// second invocation is a no-op. +func TestMigrationCalendarAligned_Idempotent(t *testing.T) { + db := setupCalendarAlignedPreMigrationDB(t) + ctx := context.Background() + + insertRateLimitRaw(t, db, "rl-1") + vkID := "vk-1" + rl := "rl-1" + insertVKRaw(t, db, vkID, "vk-1", "vk-1-value", &rl, true) + insertBudgetRaw(t, db, "budget-1", &vkID) + + require.NoError(t, migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx, db), + "first run should succeed") + require.NoError(t, migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx, db), + "second run should be a no-op via migrator ID tracking") + + // Data should still be correct after the second run. + var aligned bool + require.NoError(t, db.Table("governance_budgets"). + Select("calendar_aligned").Where("id = ?", "budget-1").Scan(&aligned).Error) + assert.True(t, aligned, "budget backfill should persist across idempotent reruns") +} + +// TestMigrationCalendarAligned_WiredIntoTriggerMigrations confirms the new +// migration is part of the startup chain so a fresh DB emerges with the column +// present on both governance_budgets and governance_rate_limits. 
+func TestMigrationCalendarAligned_WiredIntoTriggerMigrations(t *testing.T) { + _, db := setupFullMigrationDB(t) + mig := db.Migrator() + assert.True(t, mig.HasColumn(&tables.TableBudget{}, "calendar_aligned"), + "triggerMigrations should add calendar_aligned to governance_budgets") + assert.True(t, mig.HasColumn(&tables.TableRateLimit{}, "calendar_aligned"), + "triggerMigrations should add calendar_aligned to governance_rate_limits") +} + +func strPtr(s string) *string { return &s } diff --git a/framework/configstore/postgres.go b/framework/configstore/postgres.go index ecc016b68d..b88edf143b 100644 --- a/framework/configstore/postgres.go +++ b/framework/configstore/postgres.go @@ -21,12 +21,67 @@ type PostgresConfig struct { MaxOpenConns int `json:"max_open_conns"` } +// buildPostgresDSN assembles a libpq-style DSN from the validated config. +func buildPostgresDSN(config *PostgresConfig) string { + return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", + config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), + config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue()) +} + +// openPostresConnection opens a *gorm.DB against the configured Postgres instance +// using the shared bifrost logger. Used for both the throwaway migration pool +// and the runtime pool. +func openPostresConnection(dsn string, logger schemas.Logger) (*gorm.DB, error) { + return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{ + Logger: newGormLogger(logger), + }) +} + +// closeDbConn closes the *sql.DB backing a *gorm.DB, logging any error. +// Used in error paths and for the throwaway migration pool. 
+func closeDbConn(db *gorm.DB, logger schemas.Logger) { + sqlDB, err := db.DB() + if err != nil { + logger.Error("failed to resolve *sql.DB for close: %v", err) + return + } + if err := sqlDB.Close(); err != nil { + logger.Error("failed to close DB connection: %v", err) + } +} + +// applyPostgresPoolTuning applies MaxIdleConns / MaxOpenConns from config to +// the supplied *gorm.DB, falling back to defaults when the config leaves the +// field at zero. +func applyPostgresPoolTuning(db *gorm.DB, config *PostgresConfig) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + maxIdleConns := config.MaxIdleConns + if maxIdleConns == 0 { + maxIdleConns = 5 + } + sqlDB.SetMaxIdleConns(maxIdleConns) + maxOpenConns := config.MaxOpenConns + if maxOpenConns == 0 { + maxOpenConns = 50 + } + sqlDB.SetMaxOpenConns(maxOpenConns) + return nil +} + // newPostgresConfigStore creates a new Postgres config store. +// +// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not +// change result type"): a throwaway migration pool runs DDL and is closed +// immediately, then a fresh runtime pool is opened. The runtime pool's +// connections never see pre-migration schema, so their cached prepared-plans +// stay valid for the life of the process. 
func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (ConfigStore, error) { if config == nil { return nil, fmt.Errorf("config is required") } - // Validate required config if config.Host == nil || config.Host.GetValue() == "" { return nil, fmt.Errorf("postgres host is required") } @@ -45,53 +100,69 @@ func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger if config.SSLMode == nil || config.SSLMode.GetValue() == "" { return nil, fmt.Errorf("postgres ssl mode is required") } - dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue()) - db, err := gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - }), &gorm.Config{ - Logger: newGormLogger(logger), - }) + dsn := buildPostgresDSN(config) + + // Throwaway pool for schema migrations. Closing it before the runtime pool + // opens guarantees no cached prepared-plan survives the DDL. + mDb, err := openPostresConnection(dsn, logger) if err != nil { return nil, err } + if err := triggerMigrations(ctx, mDb); err != nil { + closeDbConn(mDb, logger) + return nil, err + } + closeDbConn(mDb, logger) - // Configure connection pool - sqlDB, err := db.DB() + // Runtime pool. Opens against post-migration schema. + db, err := openPostresConnection(dsn, logger) if err != nil { return nil, err } - // Set MaxIdleConns (default: 5) - maxIdleConns := config.MaxIdleConns - if maxIdleConns == 0 { - maxIdleConns = 5 + if err := applyPostgresPoolTuning(db, config); err != nil { + closeDbConn(db, logger) + return nil, err } - sqlDB.SetMaxIdleConns(maxIdleConns) - // Set MaxOpenConns (default: 50) - maxOpenConns := config.MaxOpenConns - if maxOpenConns == 0 { - maxOpenConns = 50 + d := &RDBConfigStore{logger: logger} + d.db.Store(db) + + // migrateOnFreshFn: downstream consumers (e.g. 
bifrost-enterprise) run + // their migrations via this hook on a throwaway pool that closes after fn. + d.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + tempDB, err := openPostresConnection(dsn, logger) + if err != nil { + return err + } + defer closeDbConn(tempDB, logger) + return fn(ctx, tempDB) } - sqlDB.SetMaxOpenConns(maxOpenConns) - d := &RDBConfigStore{db: db, logger: logger} - // Run migrations - if err := triggerMigrations(ctx, db); err != nil { - // Closing the DB connection - if sqlDB, dbErr := db.DB(); dbErr == nil { - if closeErr := sqlDB.Close(); closeErr != nil { - logger.Error("failed to close DB connection: %v", closeErr) - } + // refreshPoolFn: open fresh runtime pool first (so a failure leaves the + // existing pool in place), swap atomically, then close the old pool. + // sql.DB.Close blocks until in-flight queries finish, so callers already + // using the old pool complete safely. + d.refreshPoolFn = func(ctx context.Context) error { + newDB, err := openPostresConnection(dsn, logger) + if err != nil { + return fmt.Errorf("failed to open fresh runtime pool: %w", err) } - return nil, err + if err := applyPostgresPoolTuning(newDB, config); err != nil { + closeDbConn(newDB, logger) + return fmt.Errorf("failed to tune fresh runtime pool: %w", err) + } + oldDB := d.db.Swap(newDB) + if oldDB != nil { + closeDbConn(oldDB, logger) + } + return nil } - // Encrypt any plaintext rows if encryption is enabled + + // Encrypt any plaintext rows if encryption is enabled. Runs on the + // runtime pool — pure DML (SELECT + UPDATE), no DDL, so cached plans it + // installs remain valid until the next external migration batch. 
if err := d.EncryptPlaintextRows(ctx); err != nil { - if sqlDB, dbErr := db.DB(); dbErr == nil { - if closeErr := sqlDB.Close(); closeErr != nil { - logger.Error("failed to close DB connection: %v", closeErr) - } - } + closeDbConn(db, logger) return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err) } return d, nil diff --git a/framework/configstore/prompts.go b/framework/configstore/prompts.go index e760351b95..c30dacd75a 100644 --- a/framework/configstore/prompts.go +++ b/framework/configstore/prompts.go @@ -27,7 +27,7 @@ func isUniqueConstraintError(err error) bool { // GetFolders gets all folders func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) { var folders []tables.TableFolder - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Order("created_at DESC"). Find(&folders).Error; err != nil { return nil, err @@ -36,7 +36,7 @@ func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, // Get prompts count for each folder for i := range folders { var count int64 - if err := s.db.WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil { + if err := s.DB().WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil { return nil, err } folders[i].PromptsCount = int(count) @@ -48,7 +48,7 @@ func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, // GetFolderByID gets a folder by ID func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) { var folder tables.TableFolder - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). First(&folder, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound @@ -60,12 +60,12 @@ func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables. 
// CreateFolder creates a new folder func (s *RDBConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error { - return s.db.WithContext(ctx).Create(folder).Error + return s.DB().WithContext(ctx).Create(folder).Error } // UpdateFolder updates a folder func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error { - res := s.db.WithContext(ctx).Where("id = ?", folder.ID).Save(folder) + res := s.DB().WithContext(ctx).Where("id = ?", folder.ID).Save(folder) if res.Error != nil { return res.Error } @@ -79,7 +79,7 @@ func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableF // PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot // alter foreign key constraints after table creation. func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Check folder exists var folder tables.TableFolder if err := tx.First(&folder, "id = ?", id).Error; err != nil { @@ -90,7 +90,7 @@ func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error { } // PostgreSQL: ON DELETE CASCADE handles all child deletions - if s.db.Dialector.Name() == "postgres" { + if s.DB().Dialector.Name() == "postgres" { return tx.Delete(&folder).Error } @@ -135,7 +135,7 @@ func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error { // GetPrompts gets all prompts, optionally filtered by folder ID func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) { var prompts []tables.TablePrompt - query := s.db.WithContext(ctx). + query := s.DB().WithContext(ctx). Preload("Folder"). 
Order("created_at DESC") @@ -150,7 +150,7 @@ func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]ta // Get latest version for each prompt for i := range prompts { var latestVersion tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ? AND is_latest = ?", prompts[i].ID, true). First(&latestVersion).Error; err != nil { @@ -168,7 +168,7 @@ func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]ta // GetPromptByID gets a prompt by ID with latest version func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) { var prompt tables.TablePrompt - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Folder"). First(&prompt, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -179,7 +179,7 @@ func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables. // Get latest version var latestVersion tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ? AND is_latest = ?", prompt.ID, true). First(&latestVersion).Error; err != nil { @@ -195,13 +195,13 @@ func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables. // CreatePrompt creates a new prompt func (s *RDBConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { - return s.db.WithContext(ctx).Create(prompt).Error + return s.DB().WithContext(ctx).Create(prompt).Error } // UpdatePrompt updates a prompt func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { // Use Select to explicitly include FolderID so GORM writes NULL when it's nil - res := s.db.WithContext(ctx). 
+ res := s.DB().WithContext(ctx). Model(prompt). Where("id = ?", prompt.ID). Select("Name", "FolderID", "UpdatedAt"). @@ -219,7 +219,7 @@ func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TableP // PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot // alter foreign key constraints after table creation. func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Check prompt exists var prompt tables.TablePrompt if err := tx.First(&prompt, "id = ?", id).Error; err != nil { @@ -230,7 +230,7 @@ func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { } // PostgreSQL: ON DELETE CASCADE handles all child deletions - if s.db.Dialector.Name() == "postgres" { + if s.DB().Dialector.Name() == "postgres" { return tx.Delete(&prompt).Error } @@ -258,7 +258,7 @@ func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { // GetAllPromptVersions returns every version across all prompts in a single query. func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { var versions []tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Order("prompt_id ASC, version_number DESC"). Find(&versions).Error; err != nil { @@ -270,7 +270,7 @@ func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.Tab // GetPromptVersions gets all versions for a prompt func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { var versions []tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). 
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ?", promptID). Order("version_number DESC"). @@ -283,7 +283,7 @@ func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) // GetPromptVersionByID gets a version by ID func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { var version tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Preload("Prompt"). First(&version, "id = ?", id).Error; err != nil { @@ -298,7 +298,7 @@ func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*ta // GetLatestPromptVersion gets the latest version for a prompt func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) { var version tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ? AND is_latest = ?", promptID, true). First(&version).Error; err != nil { @@ -315,7 +315,7 @@ func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID st func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error { const maxRetries = 3 for attempt := 0; attempt < maxRetries; attempt++ { - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Get the next version number var maxVersionNumber int if err := tx.Model(&tables.TablePromptVersion{}). @@ -364,7 +364,7 @@ func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *table // DeletePromptVersion deletes a version and promotes the previous version to latest if needed. 
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade. func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Get the version to check if it's latest var version tables.TablePromptVersion if err := tx.First(&version, "id = ?", id).Error; err != nil { @@ -375,7 +375,7 @@ func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error } // SQLite: manually delete version messages (PostgreSQL CASCADE handles this) - if s.db.Dialector.Name() != "postgres" { + if s.DB().Dialector.Name() != "postgres" { if err := tx.Where("version_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil { return err } @@ -413,7 +413,7 @@ func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error // GetPromptSessions gets all sessions for a prompt func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) { var sessions []tables.TablePromptSession - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Preload("Version"). Where("prompt_id = ?", promptID). @@ -427,7 +427,7 @@ func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) // GetPromptSessionByID gets a session by ID func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) { var session tables.TablePromptSession - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Preload("Prompt"). Preload("Version"). 
@@ -442,7 +442,7 @@ func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*ta // CreatePromptSession creates a new session func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Verify version belongs to the same prompt if set if session.VersionID != nil { var version tables.TablePromptVersion @@ -484,7 +484,7 @@ func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *table // UpdatePromptSession updates a session and its messages func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Verify version belongs to the same prompt if set if session.VersionID != nil { var version tables.TablePromptVersion @@ -530,7 +530,7 @@ func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *table // RenamePromptSession updates only the name of a session func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error { - result := s.db.WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name) + result := s.DB().WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name) if result.Error != nil { return result.Error } @@ -543,7 +543,7 @@ func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name // DeletePromptSession deletes a session and its messages. // PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade. 
func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var session tables.TablePromptSession if err := tx.First(&session, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -553,7 +553,7 @@ func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error } // PostgreSQL: ON DELETE CASCADE handles message deletion - if s.db.Dialector.Name() == "postgres" { + if s.DB().Dialector.Name() == "postgres" { return tx.Delete(&session).Error } diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index c5f1c26ed5..b19a6b3d86 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "time" "github.com/bytedance/sonic" @@ -14,16 +15,21 @@ import ( "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/logstore" - "github.com/maximhq/bifrost/framework/migrator" "github.com/maximhq/bifrost/framework/vectorstore" "gorm.io/gorm" "gorm.io/gorm/clause" ) // RDBConfigStore represents a configuration store that uses a relational database. +// +// The runtime *gorm.DB is held behind an atomic.Pointer so RefreshConnectionPool +// can swap it out without tearing callers down. migrateOnFreshFn and refreshPoolFn +// are backend-specific hooks installed by the constructor (postgres vs sqlite). type RDBConfigStore struct { - db *gorm.DB - logger schemas.Logger + db atomic.Pointer[gorm.DB] + logger schemas.Logger + migrateOnFreshFn func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error + refreshPoolFn func(ctx context.Context) error } // getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil. 
@@ -156,7 +162,7 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC ConfigHash: config.ConfigHash, } // Delete existing client config and create new one in a transaction - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableClientConfig{}).Error; err != nil { return err } @@ -166,12 +172,51 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC // Ping checks if the database is reachable. func (s *RDBConfigStore) Ping(ctx context.Context) error { - return s.db.WithContext(ctx).Exec("SELECT 1").Error + return s.DB().WithContext(ctx).Exec("SELECT 1").Error } -// DB returns the underlying database connection. +// DB returns the current runtime database connection. The returned pointer is +// only valid for the duration of the caller's operation — after a +// RefreshConnectionPool call, future DB() calls return a fresh *gorm.DB backed +// by a different *sql.DB pool. Callers that issue multiple operations should +// call DB() per operation rather than caching the pointer. func (s *RDBConfigStore) DB() *gorm.DB { - return s.db + return s.db.Load() +} + +// RunMigration opens a throwaway connection against the same +// backing database, invokes fn with it, and closes the connection. Use this +// for DDL that must not leave cached prepared-statement plans on the runtime +// pool. After fn returns, callers should invoke RefreshConnectionPool if the +// migration altered tables the runtime pool has already queried. +// +// For SQLite, the throwaway concept doesn't apply (no server-side plan cache, +// single-writer file lock), so this runs fn against the existing *gorm.DB. +// +// Returns an error if the store was constructed without a migration hook +// wired — e.g. 
a direct `&RDBConfigStore{}` literal that skipped the +// newPostgresConfigStore / newSqliteConfigStore constructor. An explicit +// error is safer than a silent fallback to the runtime pool: running DDL +// on the runtime pool would reintroduce SQLSTATE 0A000. +func (s *RDBConfigStore) RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + if s.migrateOnFreshFn == nil { + return fmt.Errorf("configstore: migration hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore") + } + return s.migrateOnFreshFn(ctx, fn) +} + +// RefreshConnectionPool closes the runtime pool and opens a fresh one against +// the same configuration. In-flight queries on the old pool complete before +// it closes; subsequent DB() calls return the new pool, whose connections +// carry no cached plans. SQLite is a no-op. +// +// Returns an error if the store was constructed without a refresh hook wired +// (same rationale as RunMigration). +func (s *RDBConfigStore) RefreshConnectionPool(ctx context.Context) error { + if s.refreshPoolFn == nil { + return fmt.Errorf("configstore: refresh hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore") + } + return s.refreshPoolFn(ctx) } // parseGormError parses GORM errors to provide user-friendly error messages. @@ -273,7 +318,7 @@ func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tabl // GetFrameworkConfig retrieves the framework configuration from the database. 
func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error) { var dbConfig tables.TableFrameworkConfig - if err := s.db.WithContext(ctx).First(&dbConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -285,7 +330,7 @@ func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableF // GetClientConfig retrieves the client configuration from the database. func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, error) { var dbConfig tables.TableClientConfig - if err := s.db.WithContext(ctx).First(&dbConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -334,7 +379,7 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for providerName, providerConfig := range providers { dbProvider := tables.TableProvider{ @@ -497,7 +542,7 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Find the existing provider var dbProvider tables.TableProvider @@ -648,7 +693,7 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Create a deep copy of the config to avoid modifying the original configCopy, err := deepCopy(config) @@ -748,7 +793,7 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Find the existing provider var dbProvider tables.TableProvider @@ -790,7 +835,7 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo // GetProvidersConfig 
retrieves the provider configuration from the database. func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) { var dbProviders []tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { return nil, err } if len(dbProviders) == 0 { @@ -827,7 +872,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo // GetProviderConfig retrieves the provider configuration from the database. func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error) { var dbProvider tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -857,7 +902,7 @@ func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas // GetProviderKeys retrieves all keys for a provider ordered by creation time. func (s *RDBConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { var dbKeys []tables.TableKey - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Table("config_providers"). Select("config_keys.*"). Joins("LEFT JOIN config_keys ON config_keys.provider_id = config_providers.id"). @@ -906,7 +951,7 @@ func (s *RDBConfigStore) getProviderKeyByName(ctx context.Context, txDB *gorm.DB // GetProviderKey retrieves a single key for a provider. 
func (s *RDBConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { - dbKey, err := s.getProviderKeyByName(ctx, s.db, provider, keyID) + dbKey, err := s.getProviderKeyByName(ctx, s.DB(), provider, keyID) if err != nil { return nil, err } @@ -921,7 +966,7 @@ func (s *RDBConfigStore) CreateProviderKey(ctx context.Context, provider schemas if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } var dbProvider tables.TableProvider if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { @@ -946,7 +991,7 @@ func (s *RDBConfigStore) UpdateProviderKey(ctx context.Context, provider schemas if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } existingKey, err := s.getProviderKeyByName(ctx, txDB, provider, keyID) @@ -982,7 +1027,7 @@ func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } providerIDSubquery := txDB.Model(&tables.TableProvider{}). @@ -1005,7 +1050,7 @@ func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas // GetProviders retrieves all providers from the database with their governance relationships. func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) { var providers []tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil { return nil, err } return providers, nil @@ -1014,7 +1059,7 @@ func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvid // GetProvider retrieves a provider by name from the database with governance relationships. 
func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) { var providerInfo tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1026,7 +1071,7 @@ func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.Model // GetProviderByName retrieves a provider by name from the database with governance relationships. func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) { var provider tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1041,7 +1086,7 @@ func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*t func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, description string) error { // Update key-level status (for keyed providers) if keyID != "" { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Model(&tables.TableKey{}). Where("key_id = ?", keyID). Updates(map[string]interface{}{ @@ -1059,7 +1104,7 @@ func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.Mode // Update provider-level status (for keyless providers) if provider != "" { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Model(&tables.TableProvider{}). 
Where("name = ?", string(provider)). Updates(map[string]interface{}{ @@ -1082,14 +1127,14 @@ func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.Mode func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient // Get all MCP clients - if err := s.db.WithContext(ctx).Find(&dbMCPClients).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&dbMCPClients).Error; err != nil { return nil, err } if len(dbMCPClients) == 0 { return nil, nil } var clientConfig tables.TableClientConfig - if err := s.db.WithContext(ctx).First(&clientConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&clientConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Return MCP config with default ToolManagerConfig if no client config exists // This will never happen, but just in case. @@ -1163,7 +1208,7 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, // GetMCPClientsPaginated retrieves MCP clients with pagination and optional search. func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableMCPClient{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableMCPClient{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -1202,7 +1247,7 @@ func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPC // GetMCPClientByID retrieves an MCP client by ID from the database. 
func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient - if err := s.db.WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1214,7 +1259,7 @@ func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tabl // GetMCPClientByName retrieves an MCP client by name from the database. func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient - if err := s.db.WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1225,7 +1270,7 @@ func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (* // CreateMCPClientConfig creates a new MCP client configuration in the database. func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Check if a client with the same name already exists if _, err := s.GetMCPClientByName(ctx, clientConfig.Name); err == nil { return fmt.Errorf("MCP client with name '%s' already exists", clientConfig.Name) @@ -1262,7 +1307,7 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig // UpdateMCPClientConfig updates an existing MCP client configuration in the database. 
func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil { @@ -1376,7 +1421,7 @@ func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, cli if err != nil { return fmt.Errorf("failed to marshal tool name mapping: %w", err) } - return s.db.WithContext(ctx). + return s.DB().WithContext(ctx). Model(&tables.TableMCPClient{}). Where("client_id = ?", clientID). Updates(map[string]interface{}{ @@ -1388,7 +1433,7 @@ func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, cli // DeleteMCPClientConfig deletes an MCP client configuration from the database. func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil { @@ -1411,7 +1456,7 @@ func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) e // GetVectorStoreConfig retrieves the vector store configuration from the database. 
func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error) { var vectorStoreTableConfig tables.TableVectorStoreConfig - if err := s.db.WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Return default cache configuration return nil, nil @@ -1427,7 +1472,7 @@ func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore // UpdateVectorStoreConfig updates the vector store configuration in the database. func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Delete existing cache config if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableVectorStoreConfig{}).Error; err != nil { return err @@ -1449,7 +1494,7 @@ func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *ve // GetLogsStoreConfig retrieves the logs store configuration from the database. func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error) { var dbConfig tables.TableLogStoreConfig - if err := s.db.WithContext(ctx).First(&dbConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -1467,7 +1512,7 @@ func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Conf // UpdateLogsStoreConfig updates the logs store configuration in the database. 
func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableLogStoreConfig{}).Error; err != nil { return err } @@ -1487,7 +1532,7 @@ func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logs // GetConfig retrieves a specific config from the database. func (s *RDBConfigStore) GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error) { var config tables.TableGovernanceConfig - if err := s.db.WithContext(ctx).First(&config, "key = ?", key).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&config, "key = ?", key).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1502,7 +1547,7 @@ func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableG if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Save(config).Error } @@ -1510,7 +1555,7 @@ func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableG // GetModelPrices retrieves all model pricing records from the database. func (s *RDBConfigStore) GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error) { var modelPrices []tables.TableModelPricing - if err := s.db.WithContext(ctx).Find(&modelPrices).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&modelPrices).Error; err != nil { return nil, err } return modelPrices, nil @@ -1524,7 +1569,7 @@ func (s *RDBConfigStore) UpsertModelPrices(ctx context.Context, pricing *tables. 
if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } db := txDB.WithContext(ctx) @@ -1543,14 +1588,14 @@ func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error } func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) { var overrides []tables.TablePricingOverride - q := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) + q := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{}) if filters.ScopeKind != nil { q = q.Where("scope_kind = ?", *filters.ScopeKind) } @@ -1570,7 +1615,7 @@ func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters Pricin } func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -1620,7 +1665,7 @@ func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, param func (s *RDBConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) { var override tables.TablePricingOverride - if err := s.db.WithContext(ctx).First(&override, "id = ?", id).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&override, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1634,7 +1679,7 @@ func (s *RDBConfigStore) CreatePricingOverride(ctx context.Context, override *ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := 
txDB.WithContext(ctx).Create(override).Error; err != nil { return s.parseGormError(err) @@ -1647,7 +1692,7 @@ func (s *RDBConfigStore) UpdatePricingOverride(ctx context.Context, override *ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(override).Error; err != nil { return s.parseGormError(err) @@ -1660,7 +1705,7 @@ func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, t if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } res := txDB.WithContext(ctx).Delete(&tables.TablePricingOverride{}, "id = ?", id) if res.Error != nil { @@ -1677,7 +1722,7 @@ func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, t // GetModelParameters returns all stored model parameter rows. func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error) { var rows []tables.TableModelParameters - if err := s.db.WithContext(ctx).Find(&rows).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&rows).Error; err != nil { return nil, err } return rows, nil @@ -1686,7 +1731,7 @@ func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.Table // GetModelParametersByModel retrieves model parameters for a specific model. 
func (s *RDBConfigStore) GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error) { var params tables.TableModelParameters - if err := s.db.WithContext(ctx).Where("model = ?", model).First(¶ms).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("model = ?", model).First(¶ms).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1703,7 +1748,7 @@ func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tabl if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } db := txDB.WithContext(ctx) @@ -1720,7 +1765,7 @@ func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tabl func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error) { var plugins []*tables.TablePlugin - if err := s.db.WithContext(ctx).Find(&plugins).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&plugins).Error; err != nil { return nil, err } return plugins, nil @@ -1728,7 +1773,7 @@ func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, func (s *RDBConfigStore) GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error) { var plugin tables.TablePlugin - if err := s.db.WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1743,7 +1788,7 @@ func (s *RDBConfigStore) CreatePlugin(ctx context.Context, plugin *tables.TableP if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { @@ -1763,7 +1808,7 @@ func (s *RDBConfigStore) UpsertPlugin(ctx context.Context, plugin *tables.TableP if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Mark plugin as custom if path is 
not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { @@ -1802,7 +1847,7 @@ func (s *RDBConfigStore) UpdatePlugin(ctx context.Context, plugin *tables.TableP txDB = tx[0] localTx = false } else { - txDB = s.db.Begin() + txDB = s.DB().Begin() localTx = true } // Mark plugin as custom if path is not empty @@ -1835,7 +1880,7 @@ func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*g if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", name).Error } @@ -1847,12 +1892,12 @@ func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []strin var virtualKeys []tables.TableVirtualKey if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error + err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error + err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error if err != nil { return nil, err } @@ -1903,7 +1948,7 @@ func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirt var virtualKeys []tables.TableVirtualKey // Preload all relationships for complete information - if err := preloadVirtualKeyBaseRelations(s.db.WithContext(ctx)). + if err := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)). Order("created_at ASC"). Find(&virtualKeys).Error; err != nil { return nil, err @@ -1914,7 +1959,7 @@ func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirt // GetVirtualKeysPaginated retrieves virtual keys with pagination, filtering, and search support. 
func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error) { // Build base query with filters - baseQuery := s.db.WithContext(ctx).Model(&tables.TableVirtualKey{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableVirtualKey{}) // Virtual keys are either customer-scoped or team-scoped, never both. // When both filters are provided, use OR to match keys belonging to either. @@ -1998,7 +2043,7 @@ func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params Vir // GetVirtualKey retrieves a virtual key from the database. func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) { var virtualKey tables.TableVirtualKey - if err := preloadVirtualKeyDetailRelations(s.db.WithContext(ctx)). + if err := preloadVirtualKeyDetailRelations(s.DB().WithContext(ctx)). First(&virtualKey, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound @@ -2012,7 +2057,7 @@ func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables. 
func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) { valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey - query := preloadVirtualKeyBaseRelations(s.db.WithContext(ctx)) + query := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)) // Use hash-based lookup if hash column is populated, fall back to plaintext for backward compat if err := query.Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2035,7 +2080,7 @@ func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) func (s *RDBConfigStore) GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) { valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey - baseQuery := s.db.WithContext(ctx).Preload("Budgets").Preload("RateLimit") + baseQuery := s.DB().WithContext(ctx).Preload("Budgets").Preload("RateLimit") if err := baseQuery.Session(&gorm.Session{}).Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fallback: try plaintext lookup for rows not yet migrated @@ -2058,7 +2103,7 @@ func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *table if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil { return s.parseGormError(err) @@ -2072,7 +2117,7 @@ func (s *RDBConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *table if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Check if record exists by ID or Name @@ -2106,7 +2151,7 @@ func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tabl return []tables.TableKey{}, nil } var keys []tables.TableKey - if err := s.db.WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil { + if err := 
s.DB().WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil { return nil, err } return keys, nil @@ -2115,7 +2160,7 @@ func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tabl // GetKeysByProvider retrieves all keys for a specific provider func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error) { var keys []tables.TableKey - if err := s.db.WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil { return nil, err } return keys, nil @@ -2125,12 +2170,12 @@ func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) { var keys []tables.TableKey if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error + err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error + err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error if err != nil { return nil, err } @@ -2158,7 +2203,7 @@ func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ( // DeleteVirtualKey deletes a virtual key from the database. 
func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error { - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var virtualKey tables.TableVirtualKey if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2243,7 +2288,7 @@ func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error // GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database. func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error) { var virtualKey tables.TableVirtualKey - if err := s.db.WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []tables.TableVirtualKeyProviderConfig{}, nil } @@ -2253,7 +2298,7 @@ func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtu return nil, nil } var providerConfigs []tables.TableVirtualKeyProviderConfig - if err := s.db.WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil { return nil, err } return providerConfigs, nil @@ -2265,7 +2310,7 @@ func (s *RDBConfigStore) CreateVirtualKeyProviderConfig(ctx context.Context, vir if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Store keys before create keysToAssociate := virtualKeyProviderConfig.Keys @@ -2336,7 +2381,7 @@ func (s *RDBConfigStore) UpdateVirtualKeyProviderConfig(ctx context.Context, vir if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Store 
keys before save @@ -2411,7 +2456,7 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // First fetch the provider config to get budget and rate limit IDs var providerConfig tables.TableVirtualKeyProviderConfig @@ -2443,7 +2488,7 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id // GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database. func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error) { var virtualKey tables.TableVirtualKey - if err := s.db.WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []tables.TableVirtualKeyMCPConfig{}, nil } @@ -2453,7 +2498,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKey return nil, nil } var mcpConfigs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { return nil, err } return mcpConfigs, nil @@ -2462,7 +2507,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKey // GetVirtualKeyMCPConfigsByMCPClientID retrieves all VK MCP configs for a given MCP client. 
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) { var configs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil { return nil, err } return configs, nil @@ -2474,7 +2519,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Conte return nil, nil } var configs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil { return nil, err } return configs, nil @@ -2487,7 +2532,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context return nil, nil } var configs []tables.TableVirtualKeyMCPConfig - err := s.db.WithContext(ctx). + err := s.DB().WithContext(ctx). Preload("MCPClient"). Joins("JOIN config_mcp_clients ON config_mcp_clients.id = governance_virtual_key_mcp_configs.mcp_client_id"). Where("config_mcp_clients.client_id IN ?", clientIDs). 
@@ -2504,7 +2549,7 @@ func (s *RDBConfigStore) CreateVirtualKeyMCPConfig(ctx context.Context, virtualK if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(virtualKeyMCPConfig).Error; err != nil { return s.parseGormError(err) @@ -2518,7 +2563,7 @@ func (s *RDBConfigStore) UpdateVirtualKeyMCPConfig(ctx context.Context, virtualK if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(virtualKeyMCPConfig).Error; err != nil { return s.parseGormError(err) @@ -2532,7 +2577,7 @@ func (s *RDBConfigStore) DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "id = ?", id).Error } @@ -2542,7 +2587,7 @@ const teamSelectWithVKCount = "governance_teams.*, (SELECT COUNT(*) FROM governa // GetTeams retrieves all teams from the database. func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error) { // Preload relationships for complete information - query := s.db.WithContext(ctx). + query := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Preload("Customer").Preload("Budget").Preload("RateLimit") // Optional filtering by customer @@ -2558,7 +2603,7 @@ func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tab // GetTeamsPaginated retrieves teams with pagination, filtering, and search support. 
func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableTeam{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableTeam{}) if params.CustomerID != "" { baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID) @@ -2600,7 +2645,7 @@ func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQuer // GetTeam retrieves a specific team from the database. func (s *RDBConfigStore) GetTeam(ctx context.Context, id string) (*tables.TableTeam, error) { var team tables.TableTeam - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Preload("Customer").Preload("Budget").Preload("RateLimit"). First(&team, "id = ?", id).Error; err != nil { @@ -2618,7 +2663,7 @@ func (s *RDBConfigStore) CreateTeam(ctx context.Context, team *tables.TableTeam, if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(team).Error; err != nil { return s.parseGormError(err) @@ -2632,7 +2677,7 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(team).Error; err != nil { return s.parseGormError(err) @@ -2642,7 +2687,7 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, // DeleteTeam deletes a team from the database. 
func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var team tables.TableTeam if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&team, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2689,7 +2734,7 @@ func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { // GetCustomers retrieves all customers from the database. func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) { var customers []tables.TableCustomer - if err := preloadCustomerRelations(s.db.WithContext(ctx), ""). + if err := preloadCustomerRelations(s.DB().WithContext(ctx), ""). Order("created_at ASC"). Find(&customers).Error; err != nil { return nil, err @@ -2699,7 +2744,7 @@ func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustom // GetCustomersPaginated retrieves customers with pagination and optional search filtering. func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableCustomer{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableCustomer{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) @@ -2731,7 +2776,7 @@ func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params Custo // GetCustomer retrieves a specific customer from the database. func (s *RDBConfigStore) GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) { var customer tables.TableCustomer - if err := preloadCustomerRelations(s.db.WithContext(ctx), ""). + if err := preloadCustomerRelations(s.DB().WithContext(ctx), ""). 
First(&customer, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound @@ -2747,7 +2792,7 @@ func (s *RDBConfigStore) CreateCustomer(ctx context.Context, customer *tables.Ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(customer).Error; err != nil { return s.parseGormError(err) @@ -2761,7 +2806,7 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(customer).Error; err != nil { return s.parseGormError(err) @@ -2771,7 +2816,7 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta // DeleteCustomer deletes a customer from the database. func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var customer tables.TableCustomer if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&customer, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2822,7 +2867,7 @@ func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { // GetRateLimits retrieves all rate limits from the database. 
func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { var rateLimits []tables.TableRateLimit - if err := s.db.WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil { + if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil { return nil, err } return rateLimits, nil @@ -2834,7 +2879,7 @@ func (s *RDBConfigStore) GetRateLimit(ctx context.Context, id string, tx ...*gor if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } var rateLimit tables.TableRateLimit if err := txDB.WithContext(ctx).First(&rateLimit, "id = ?", id).Error; err != nil { @@ -2852,7 +2897,7 @@ func (s *RDBConfigStore) CreateRateLimit(ctx context.Context, rateLimit *tables. if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(rateLimit).Error; err != nil { return s.parseGormError(err) @@ -2866,7 +2911,7 @@ func (s *RDBConfigStore) UpdateRateLimit(ctx context.Context, rateLimit *tables. if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(rateLimit).Error; err != nil { return s.parseGormError(err) @@ -2880,7 +2925,7 @@ func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tab if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for _, rl := range rateLimits { if err := txDB.WithContext(ctx).Save(rl).Error; err != nil { @@ -2896,7 +2941,7 @@ func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...* if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", id).Error; err != nil { return s.parseGormError(err) @@ -2907,7 +2952,7 @@ func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...* // GetBudgets retrieves all budgets from the database. 
func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) { var budgets []tables.TableBudget - if err := s.db.WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil { + if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil { return nil, err } return budgets, nil @@ -2919,7 +2964,7 @@ func (s *RDBConfigStore) GetBudget(ctx context.Context, id string, tx ...*gorm.D if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } var budget tables.TableBudget if err := txDB.WithContext(ctx).First(&budget, "id = ?", id).Error; err != nil { @@ -2937,7 +2982,7 @@ func (s *RDBConfigStore) CreateBudget(ctx context.Context, budget *tables.TableB if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(budget).Error; err != nil { return s.parseGormError(err) @@ -2951,7 +2996,7 @@ func (s *RDBConfigStore) UpdateBudgets(ctx context.Context, budgets []*tables.Ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for _, b := range budgets { if err := txDB.WithContext(ctx).Save(b).Error; err != nil { @@ -2967,7 +3012,7 @@ func (s *RDBConfigStore) UpdateBudget(ctx context.Context, budget *tables.TableB if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(budget).Error; err != nil { return s.parseGormError(err) @@ -2981,7 +3026,7 @@ func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gor if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", id).Error; err != nil { return s.parseGormError(err) @@ -2992,7 +3037,7 @@ func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gor // UpdateBudgetUsage updates only the current_usage field of a budget. // Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. 
func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Session(&gorm.Session{SkipHooks: true}). Model(&tables.TableBudget{}). Where("id = ?", id). @@ -3009,7 +3054,7 @@ func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, curre // UpdateRateLimitUsage updates only the usage fields of a rate limit. // Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Session(&gorm.Session{SkipHooks: true}). Model(&tables.TableRateLimit{}). Where("id = ?", id). @@ -3029,7 +3074,7 @@ func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, to // loadRoutingRulesOrdered loads routing rules with Targets preloaded, using consistent ordering: // rules by priority ASC, created_at DESC, id ASC; targets by weight DESC for deterministic ordering. func (s *RDBConfigStore) loadRoutingRulesOrdered(ctx context.Context, dest *[]tables.TableRoutingRule, scopes ...func(*gorm.DB) *gorm.DB) error { - q := s.db.WithContext(ctx). + q := s.DB().WithContext(ctx). Preload("Targets", func(db *gorm.DB) *gorm.DB { return db.Order("weight DESC"). Order("COALESCE(provider, '') ASC"). @@ -3054,7 +3099,7 @@ func (s *RDBConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRou // GetRoutingRulesPaginated retrieves routing rules with pagination and optional search filtering. 
func (s *RDBConfigStore) GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableRoutingRule{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableRoutingRule{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -3135,12 +3180,12 @@ func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []stri var routingRules []tables.TableRoutingRule if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error + err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error + err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error if err != nil { return nil, err } @@ -3150,7 +3195,7 @@ func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []stri // CreateRoutingRule creates a new routing rule in the database. func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error { - database := s.db + database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } @@ -3199,7 +3244,7 @@ func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.Tab // UpdateRoutingRule updates an existing routing rule in the database. // It enforces the same unique-priority-per-scope invariant as CreateRoutingRule. 
func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error { - database := s.db + database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } @@ -3250,7 +3295,7 @@ func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.Tab // DeleteRoutingRule deletes a routing rule and its targets from the database. func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error { - database := s.db + database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } @@ -3273,7 +3318,7 @@ func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx .. // GetModelConfigs retrieves all model configs from the database. func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) { var modelConfigs []tables.TableModelConfig - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil { return nil, err } return modelConfigs, nil @@ -3281,7 +3326,7 @@ func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableMod // GetModelConfigsPaginated retrieves model configs with pagination, filtering, and search support. func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableModelConfig{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableModelConfig{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -3322,7 +3367,7 @@ func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params Mo // GetModelConfig retrieves a specific model config from the database by model name and optional provider. 
func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) { var modelConfig tables.TableModelConfig - query := s.db.WithContext(ctx).Where("model_name = ?", modelName) + query := s.DB().WithContext(ctx).Where("model_name = ?", modelName) if provider != nil { query = query.Where("provider = ?", *provider) } else { @@ -3340,7 +3385,7 @@ func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, p // GetModelConfigByID retrieves a specific model config from the database by ID. func (s *RDBConfigStore) GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) { var modelConfig tables.TableModelConfig - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -3355,7 +3400,7 @@ func (s *RDBConfigStore) CreateModelConfig(ctx context.Context, modelConfig *tab if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(modelConfig).Error; err != nil { return s.parseGormError(err) @@ -3369,7 +3414,7 @@ func (s *RDBConfigStore) UpdateModelConfig(ctx context.Context, modelConfig *tab if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(modelConfig).Error; err != nil { return s.parseGormError(err) @@ -3383,7 +3428,7 @@ func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs [] if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for _, mc := range modelConfigs { if err := txDB.WithContext(ctx).Save(mc).Error; err != nil { @@ -3395,7 +3440,7 @@ func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs [] // DeleteModelConfig deletes 
a model config from the database. func (s *RDBConfigStore) DeleteModelConfig(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // First fetch the model config to get budget and rate limit IDs var modelConfig tables.TableModelConfig if err := tx.First(&modelConfig, "id = ?", id).Error; err != nil { @@ -3443,7 +3488,7 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo var pricingOverrides []tables.TablePricingOverride var governanceConfigs []tables.TableGovernanceConfig - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("ProviderConfigs"). Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, key_id, models_json, provider") @@ -3451,34 +3496,34 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo Find(&virtualKeys).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). 
Find(&teams).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&customers).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&customers).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&budgets).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&budgets).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&rateLimits).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&rateLimits).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&modelConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&modelConfigs).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&providers).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&providers).Error; err != nil { return nil, err } if err := s.loadRoutingRulesOrdered(ctx, &routingRules); err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&pricingOverrides).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&pricingOverrides).Error; err != nil { return nil, err } // Fetching governance config for username and password - if err := s.db.WithContext(ctx).Find(&governanceConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&governanceConfigs).Error; err != nil { return nil, err } // Check if any config is present @@ -3533,22 +3578,22 @@ func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) var password *string var isEnabled bool var disableAuthOnInference bool - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - if err := 
s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } @@ -3566,7 +3611,7 @@ func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) // UpdateAuthConfig updates the auth configuration in the database. func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfig) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Save(&tables.TableGovernanceConfig{ Key: tables.ConfigAdminUsernameKey, Value: config.AdminUserName.GetValue(), @@ -3598,7 +3643,7 @@ func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfi // GetProxyConfig retrieves the proxy configuration from the database. 
func (s *RDBConfigStore) GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error) { var configEntry tables.TableGovernanceConfig - if err := s.db.WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -3645,7 +3690,7 @@ func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.G if err != nil { return fmt.Errorf("failed to marshal proxy config: %w", err) } - return s.db.WithContext(ctx).Save(&tables.TableGovernanceConfig{ + return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigProxyKey, Value: string(configJSON), }).Error @@ -3654,7 +3699,7 @@ func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.G // GetRestartRequiredConfig retrieves the restart required configuration from the database. func (s *RDBConfigStore) GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error) { var configEntry tables.TableGovernanceConfig - if err := s.db.WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -3676,7 +3721,7 @@ func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *t if err != nil { return fmt.Errorf("failed to marshal restart required config: %w", err) } - return s.db.WithContext(ctx).Save(&tables.TableGovernanceConfig{ + return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigRestartRequiredKey, Value: string(configJSON), }).Error @@ -3684,7 +3729,7 @@ func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *t // ClearRestartRequiredConfig clears 
the restart required configuration in the database. func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error { - return s.db.WithContext(ctx).Save(&tables.TableGovernanceConfig{ + return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigRestartRequiredKey, Value: `{"required":false,"reason":""}`, }).Error @@ -3694,11 +3739,11 @@ func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error { func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables.SessionsTable, error) { var session tables.SessionsTable tokenHash := encrypt.HashSHA256(token) - err := s.db.WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error + err := s.DB().WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fall back to plaintext lookup for backward compatibility - if err := s.db.WithContext(ctx).First(&session, "token = ?", token).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&session, "token = ?", token).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -3713,31 +3758,31 @@ func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables. // CreateSession creates a new session in the database. func (s *RDBConfigStore) CreateSession(ctx context.Context, session *tables.SessionsTable) error { - return s.db.WithContext(ctx).Create(session).Error + return s.DB().WithContext(ctx).Create(session).Error } // DeleteSession deletes a session from the database. 
func (s *RDBConfigStore) DeleteSession(ctx context.Context, token string) error { tokenHash := encrypt.HashSHA256(token) - result := s.db.WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash) + result := s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { // Fall back to plaintext lookup for backward compatibility - return s.db.WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error + return s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error } return nil } // FlushSessions flushes all sessions from the database. func (s *RDBConfigStore) FlushSessions(ctx context.Context) error { - return s.db.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error + return s.DB().WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error } // ExecuteTransaction executes a transaction. func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { - return s.db.WithContext(ctx).Transaction(fn) + return s.DB().WithContext(ctx).Transaction(fn) } // RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound @@ -3769,12 +3814,12 @@ func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx contex // doesTableExist checks if a table exists in the database. func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool { - return s.db.WithContext(ctx).Migrator().HasTable(tableName) + return s.DB().WithContext(ctx).Migrator().HasTable(tableName) } // removeNullKeys removes null keys from the database. 
func (s *RDBConfigStore) removeNullKeys(ctx context.Context) error { - return s.db.WithContext(ctx).Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error + return s.DB().WithContext(ctx).Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error } // removeDuplicateKeysAndNullKeys removes duplicate keys based on key_id and value combination @@ -3793,7 +3838,7 @@ func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) err s.logger.Debug("deleting duplicate keys from the database") // Find and delete duplicate keys, keeping only the one with the smallest ID // This query deletes all records except the one with the minimum ID for each (key_id, value) pair - result := s.db.WithContext(ctx).Exec(` + result := s.DB().WithContext(ctx).Exec(` DELETE FROM config_keys WHERE id NOT IN ( SELECT MIN(id) @@ -3809,18 +3854,9 @@ func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) err return nil } -// RunMigration runs a migration. -func (s *RDBConfigStore) RunMigration(ctx context.Context, migration *migrator.Migration) error { - if migration == nil { - return fmt.Errorf("migration cannot be nil") - } - m := migrator.New(s.db, migrator.DefaultOptions, []*migrator.Migration{migration}) - return m.Migrate() -} - // Close closes the SQLite config store. func (s *RDBConfigStore) Close(ctx context.Context) error { - sqlDB, err := s.db.DB() + sqlDB, err := s.DB().DB() if err != nil { return err } @@ -3836,7 +3872,7 @@ func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableD } // Use GORM clause-based insert for dialect-appropriate SQL - result := s.db.WithContext(ctx).Clauses( + result := s.DB().WithContext(ctx).Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "lock_key"}}, DoNothing: true, @@ -3854,7 +3890,7 @@ func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableD // GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist. 
func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error) { var lock tables.TableDistributedLock - result := s.db.WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock) + result := s.DB().WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -3869,7 +3905,7 @@ func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.T // UpdateLockExpiry updates the expiration time for an existing lock. // Only succeeds if the holder ID matches the current lock holder. func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error { - result := s.db.WithContext(ctx).Model(&tables.TableDistributedLock{}). + result := s.DB().WithContext(ctx).Model(&tables.TableDistributedLock{}). Where("lock_key = ? AND holder_id = ? AND expires_at > ?", lockKey, holderID, time.Now().UTC()). Update("expires_at", expiresAt) @@ -3887,7 +3923,7 @@ func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID // ReleaseLock deletes a lock if the holder ID matches. // Returns true if the lock was released, false if it wasn't held by the given holder. func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error) { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("lock_key = ? AND holder_id = ?", lockKey, holderID). Delete(&tables.TableDistributedLock{}) @@ -3901,7 +3937,7 @@ func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID stri // CleanupExpiredLocks removes all locks that have expired. // Returns the number of locks cleaned up. func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("expires_at < ?", time.Now().UTC()). 
Delete(&tables.TableDistributedLock{}) @@ -3915,7 +3951,7 @@ func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) // CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired. // Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired. func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error) { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("lock_key = ? AND expires_at < ?", lockKey, time.Now().UTC()). Delete(&tables.TableDistributedLock{}) @@ -3931,7 +3967,7 @@ func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey st // GetOauthConfigByID retrieves an OAuth config by its ID func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig - result := s.db.WithContext(ctx).Where("id = ?", id).First(&config) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -3945,7 +3981,7 @@ func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*ta // State is unique per OAuth flow (used for CSRF protection on callback) func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig - result := s.db.WithContext(ctx).Where("state = ?", state).First(&config) + result := s.DB().WithContext(ctx).Where("state = ?", state).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -3958,7 +3994,7 @@ func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string // GetOauthTokenByID retrieves an OAuth token by its ID func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error) { var 
token tables.TableOauthToken - result := s.db.WithContext(ctx).Where("id = ?", id).First(&token) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -3970,7 +4006,7 @@ func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tab // CreateOauthConfig creates a new OAuth config func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error { - result := s.db.WithContext(ctx).Create(config) + result := s.DB().WithContext(ctx).Create(config) if result.Error != nil { return fmt.Errorf("failed to create oauth config: %w", result.Error) } @@ -3979,7 +4015,7 @@ func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.T // CreateOauthToken creates a new OAuth token func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error { - result := s.db.WithContext(ctx).Create(token) + result := s.DB().WithContext(ctx).Create(token) if result.Error != nil { return fmt.Errorf("failed to create oauth token: %w", result.Error) } @@ -3988,7 +4024,7 @@ func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.Tab // UpdateOauthConfig updates an existing OAuth config func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error { - result := s.db.WithContext(ctx).Save(config) + result := s.DB().WithContext(ctx).Save(config) if result.Error != nil { return fmt.Errorf("failed to update oauth config: %w", result.Error) } @@ -3997,7 +4033,7 @@ func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.T // UpdateOauthToken updates an existing OAuth token func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error { - result := s.db.WithContext(ctx).Save(token) + result := s.DB().WithContext(ctx).Save(token) if result.Error != nil { return 
fmt.Errorf("failed to update oauth token: %w", result.Error) } @@ -4006,7 +4042,7 @@ func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.Tab // DeleteOauthToken deletes an OAuth token by its ID func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth token: %w", result.Error) } @@ -4016,7 +4052,7 @@ func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error // GetExpiringOauthTokens retrieves tokens that are expiring before the given time func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error) { var tokens []*tables.TableOauthToken - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("expires_at < ?", before). 
Find(&tokens) if result.Error != nil { @@ -4028,7 +4064,7 @@ func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time // GetOauthConfigByTokenID retrieves an OAuth config that references a specific token func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig - result := s.db.WithContext(ctx).Where("token_id = ?", tokenID).First(&config) + result := s.DB().WithContext(ctx).Where("token_id = ?", tokenID).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4043,7 +4079,7 @@ func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID st // GetOauthUserSessionByID retrieves a per-user OAuth session by its ID func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession - result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4056,7 +4092,7 @@ func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) // GetOauthUserSessionByState retrieves a per-user OAuth session by its state token func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession - result := s.db.WithContext(ctx).Where("state = ?", state).First(&session) + result := s.DB().WithContext(ctx).Where("state = ?", state).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4070,7 +4106,7 @@ func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state s // Returns nil if the session doesn't exist or has already been claimed by 
another request. func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession - result := s.db.WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session) + result := s.DB().WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4078,7 +4114,7 @@ func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state return nil, fmt.Errorf("failed to claim oauth user session by state: %w", result.Error) } // Atomically transition from "pending" to "claiming" to prevent concurrent claims - updateResult := s.db.WithContext(ctx).Model(&tables.TableOauthUserSession{}). + updateResult := s.DB().WithContext(ctx).Model(&tables.TableOauthUserSession{}). Where("id = ? AND status = ?", session.ID, "pending"). Update("status", "claiming") if updateResult.Error != nil { @@ -4095,7 +4131,7 @@ func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession tokenHash := encrypt.HashSHA256(sessionToken) - result := s.db.WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session) + result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4107,7 +4143,7 @@ func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, // CreateOauthUserSession creates a new per-user OAuth session func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { - result := s.db.WithContext(ctx).Create(session) + result := 
s.DB().WithContext(ctx).Create(session) if result.Error != nil { return fmt.Errorf("failed to create oauth user session: %w", result.Error) } @@ -4116,7 +4152,7 @@ func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *ta // UpdateOauthUserSession updates an existing per-user OAuth session func (s *RDBConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { - result := s.db.WithContext(ctx).Save(session) + result := s.DB().WithContext(ctx).Save(session) if result.Error != nil { return fmt.Errorf("failed to update oauth user session: %w", result.Error) } @@ -4133,11 +4169,11 @@ func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtua var result *gorm.DB if userID != "" { - result = s.db.WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token) + result = s.DB().WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token) } else if virtualKeyID != "" { - result = s.db.WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token) + result = s.DB().WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token) } else if sessionToken != "" { - result = s.db.WithContext(ctx).Where("session_token = ? AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token) + result = s.DB().WithContext(ctx).Where("session_token = ? 
AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token) } else { return nil, nil } @@ -4154,7 +4190,7 @@ func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtua func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { var token tables.TableOauthUserToken tokenHash := encrypt.HashSHA256(sessionToken) - result := s.db.WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token) + result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4170,7 +4206,7 @@ func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, se func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { // Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing // duplicate tokens when concurrent requests race on the same identity+client pair. - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if token.UserID != nil && *token.UserID != "" { var existing tables.TableOauthUserToken err := tx.Where("user_id = ? 
AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error @@ -4202,7 +4238,7 @@ func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables // UpdateOauthUserToken updates an existing per-user OAuth token func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { - result := s.db.WithContext(ctx).Save(token) + result := s.DB().WithContext(ctx).Save(token) if result.Error != nil { return fmt.Errorf("failed to update oauth user token: %w", result.Error) } @@ -4211,7 +4247,7 @@ func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables // DeleteOauthUserToken deletes a per-user OAuth token by its ID func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth user token: %w", result.Error) } @@ -4220,7 +4256,7 @@ func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) er // DeleteOauthUserTokensByMCPClient deletes all per-user OAuth tokens for a specific MCP client func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { - result := s.db.WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{}) + result := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth user tokens for mcp client: %w", result.Error) } @@ -4232,7 +4268,7 @@ func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, m // GetPerUserOAuthClientByClientID retrieves a dynamically registered OAuth client by its client_id. 
func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { var client tables.TablePerUserOAuthClient - result := s.db.WithContext(ctx).Where("client_id = ?", clientID).First(&client) + result := s.DB().WithContext(ctx).Where("client_id = ?", clientID).First(&client) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4244,7 +4280,7 @@ func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, cl // CreatePerUserOAuthClient creates a new dynamically registered OAuth client. func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { - result := s.db.WithContext(ctx).Create(client) + result := s.DB().WithContext(ctx).Create(client) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth client: %w", result.Error) } @@ -4255,7 +4291,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *t func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession tokenHash := encrypt.HashSHA256(accessToken) - result := s.db.WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { + result := s.DB().WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, value, encryption_status") }).First(&session) if result.Error != nil { @@ -4270,7 +4306,7 @@ func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context // GetPerUserOAuthSessionByID retrieves a Bifrost-issued session by its ID. 
func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession - result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4282,7 +4318,7 @@ func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id stri // CreatePerUserOAuthSession creates a new Bifrost-issued OAuth session. func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { - result := s.db.WithContext(ctx).Create(session) + result := s.DB().WithContext(ctx).Create(session) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth session: %w", result.Error) } @@ -4291,7 +4327,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session // UpdatePerUserOAuthSession updates a Bifrost-issued OAuth session (e.g., to attach user identity). func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { - result := s.db.WithContext(ctx).Save(session) + result := s.DB().WithContext(ctx).Save(session) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth session: %w", result.Error) } @@ -4300,7 +4336,7 @@ func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session // DeletePerUserOAuthSession deletes a Bifrost-issued OAuth session by ID. 
func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{}) if result.Error != nil { return fmt.Errorf("failed to delete per-user oauth session: %w", result.Error) } @@ -4311,7 +4347,7 @@ func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id strin func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { var codeRecord tables.TablePerUserOAuthCode codeHash := encrypt.HashSHA256(code) - result := s.db.WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord) + result := s.DB().WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4323,7 +4359,7 @@ func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code str // CreatePerUserOAuthCode creates a new authorization code record. func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { - result := s.db.WithContext(ctx).Create(code) + result := s.DB().WithContext(ctx).Create(code) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth code: %w", result.Error) } @@ -4335,7 +4371,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *table func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { codeHash := encrypt.HashSHA256(code) var codeRecord tables.TablePerUserOAuthCode - result := s.db.WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord) + result := s.DB().WithContext(ctx).Where("code_hash = ? 
AND used = ?", codeHash, false).First(&codeRecord) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4343,7 +4379,7 @@ func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) return nil, fmt.Errorf("failed to find per-user oauth code: %w", result.Error) } // Atomically mark as used - updateResult := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}). + updateResult := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}). Where("id = ? AND used = ?", codeRecord.ID, false). Update("used", true) if updateResult.Error != nil { @@ -4358,7 +4394,7 @@ func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) // UpdatePerUserOAuthCode updates an authorization code record (e.g., marking as used). func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { - result := s.db.WithContext(ctx).Save(code) + result := s.DB().WithContext(ctx).Save(code) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth code: %w", result.Error) } @@ -4370,7 +4406,7 @@ func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *table // GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID. func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { var flow tables.TablePerUserOAuthPendingFlow - result := s.db.WithContext(ctx).Where("id = ?", id).First(&flow) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&flow) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4382,7 +4418,7 @@ func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id stri // CreatePerUserOAuthPendingFlow persists a new pending consent flow. 
func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { - result := s.db.WithContext(ctx).Create(flow) + result := s.DB().WithContext(ctx).Create(flow) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error) } @@ -4391,7 +4427,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow // UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step). func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { - result := s.db.WithContext(ctx).Save(flow) + result := s.DB().WithContext(ctx).Save(flow) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error) } @@ -4400,7 +4436,7 @@ func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow // DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted. func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error) } @@ -4409,14 +4445,14 @@ func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id s func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { now := time.Now().UTC() - result := s.db.WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) + result := s.DB().WithContext(ctx).Where("id = ? 
AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) } if result.RowsAffected == 0 { // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). var count int64 - if err := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { + if err := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) } if count > 0 { @@ -4430,7 +4466,7 @@ func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id // and creates the authorization code in a single transaction. func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { var rowsAffected int64 - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // 1. Consume the pending flow (atomic idempotency guard). // Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check. now := time.Now().UTC() @@ -4479,8 +4515,8 @@ func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Contex // linked to this gateway session ID. This supports per-service proxy tokens // (e.g. "flow::") where each MCP service gets its own hash. 
var tokens []tables.TableOauthUserToken - subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) - result := s.db.WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) + subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.DB().WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) if result.Error != nil { return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error) } @@ -4510,8 +4546,8 @@ func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.C // Update all tokens whose session_token_hash matches any upstream session // linked to this gateway session ID. - subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) - result := s.db.WithContext(ctx).Model(&tables.TableOauthUserToken{}). + subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.DB().WithContext(ctx).Model(&tables.TableOauthUserToken{}). Where("session_token_hash IN (?)", subquery). 
Updates(updates) if result.Error != nil { diff --git a/framework/configstore/rdb_test.go b/framework/configstore/rdb_test.go index 4877dd02fc..48325f82f2 100644 --- a/framework/configstore/rdb_test.go +++ b/framework/configstore/rdb_test.go @@ -53,10 +53,13 @@ func setupRDBTestStore(t *testing.T) *RDBConfigStore { err = db.SetupJoinTable(&tables.TableVirtualKeyProviderConfig{}, "Keys", &tables.TableVirtualKeyProviderConfigKey{}) require.NoError(t, err, "Failed to setup join table") - return &RDBConfigStore{ - db: db, - logger: nil, + s := &RDBConfigStore{logger: nil} + s.db.Store(db) + s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, s.DB()) } + s.refreshPoolFn = func(ctx context.Context) error { return nil } + return s } // ============================================================================= @@ -718,7 +721,7 @@ func TestCreateVirtualKeyProviderConfig_WithKeys(t *testing.T) { // Load with keys var configWithKeys tables.TableVirtualKeyProviderConfig - err = store.db.Preload("Keys").First(&configWithKeys, "id = ?", configs[0].ID).Error + err = store.DB().Preload("Keys").First(&configWithKeys, "id = ?", configs[0].ID).Error require.NoError(t, err) assert.Len(t, configWithKeys.Keys, 1) } @@ -1203,7 +1206,7 @@ func createTestPromptTree(t *testing.T, store *RDBConfigStore, ctx context.Conte func countRows(t *testing.T, store *RDBConfigStore, model interface{}) int64 { t.Helper() var count int64 - require.NoError(t, store.db.Model(model).Count(&count).Error) + require.NoError(t, store.DB().Model(model).Count(&count).Error) return count } @@ -1389,7 +1392,7 @@ func TestDeletePromptSession(t *testing.T) { // Session messages for that session should be gone var msgCount int64 - require.NoError(t, store.db.Model(&tables.TablePromptSessionMessage{}).Where("session_id = ?", sessionID).Count(&msgCount).Error) + require.NoError(t, store.DB().Model(&tables.TablePromptSessionMessage{}).Where("session_id = ?", 
sessionID).Count(&msgCount).Error) assert.Equal(t, int64(0), msgCount) }) diff --git a/framework/configstore/sqlite.go b/framework/configstore/sqlite.go index 4c4cbe8594..9482801d08 100644 --- a/framework/configstore/sqlite.go +++ b/framework/configstore/sqlite.go @@ -35,7 +35,16 @@ func newSqliteConfigStore(ctx context.Context, config *SQLiteConfig, logger sche return nil, err } logger.Debug("db opened for configstore") - s := &RDBConfigStore{db: db, logger: logger} + s := &RDBConfigStore{logger: logger} + s.db.Store(db) + // SQLite has no server-side prepared-plan cache, and opening a second + // handle on the same file would contend for the single-writer lock — + // so both hooks operate on the existing *gorm.DB. + s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, s.DB()) + } + s.refreshPoolFn = func(ctx context.Context) error { return nil } + logger.Debug("running migration to remove duplicate keys") // Run migration to remove duplicate keys before AutoMigrate if err := s.removeDuplicateKeysAndNullKeys(ctx); err != nil { diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 3fbb678159..16cedc6b6a 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -9,7 +9,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/logstore" - "github.com/maximhq/bifrost/framework/migrator" "github.com/maximhq/bifrost/framework/vectorstore" "gorm.io/gorm" ) @@ -393,8 +392,25 @@ type ConfigStore interface { // DB returns the underlying database connection. DB() *gorm.DB - // Migration manager - RunMigration(ctx context.Context, migration *migrator.Migration) error + // RunMigration opens a throwaway *gorm.DB against the same + // backing database, invokes fn with it, and closes the connection. 
Use + // this for DDL (typically downstream-consumer migrations) that must not + // leave cached prepared-statement plans on the runtime pool. + // + // After fn returns successfully, callers should invoke + // RefreshConnectionPool if the migration altered tables the runtime pool + // has already queried — otherwise SQLSTATE 0A000 can surface on reads + // whose cached plans predate the DDL. + // + // For SQLite backends, this is a pass-through that runs fn on the + // existing connection (no server-side plan cache, single-writer lock). + RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error + + // RefreshConnectionPool tears down the runtime pool and opens a fresh + // one against the same configuration. In-flight queries on the old + // pool complete before it closes; subsequent DB() calls return the new + // pool, whose connections carry no cached plans. SQLite is a no-op. + RefreshConnectionPool(ctx context.Context) error // Cleanup Close(ctx context.Context) error diff --git a/framework/configstore/tables/budget.go b/framework/configstore/tables/budget.go index 2d7d397d26..897cfec4d9 100644 --- a/framework/configstore/tables/budget.go +++ b/framework/configstore/tables/budget.go @@ -19,6 +19,8 @@ type TableBudget struct { VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id,omitempty"` ProviderConfigID *uint `gorm:"index" json:"provider_config_id,omitempty"` + CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries + // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` @@ -36,7 +38,6 @@ func (b *TableBudget) BeforeSave(tx *gorm.DB) error { if b.VirtualKeyID != nil && b.ProviderConfigID != nil { return fmt.Errorf("budget cannot belong to both a virtual key and a 
provider config") } - // Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y") if d, err := ParseDuration(b.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration) diff --git a/framework/configstore/tables/ratelimit.go b/framework/configstore/tables/ratelimit.go index 7147e7b89f..0268b53164 100644 --- a/framework/configstore/tables/ratelimit.go +++ b/framework/configstore/tables/ratelimit.go @@ -23,6 +23,8 @@ type TableRateLimit struct { RequestCurrentUsage int64 `gorm:"default:0" json:"request_current_usage"` // Current request usage RequestLastReset time.Time `gorm:"index" json:"request_last_reset"` // Last time request counter was reset + CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries + // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` diff --git a/framework/configstore/tables/team.go b/framework/configstore/tables/team.go index e96614c600..4beee97ab9 100644 --- a/framework/configstore/tables/team.go +++ b/framework/configstore/tables/team.go @@ -25,14 +25,14 @@ type TableTeam struct { // Computed (not a DB column) — populated via correlated subquery in query layer, hence no migration VirtualKeyCount int64 `gorm:"->;-:migration" json:"virtual_key_count"` - Profile *string `gorm:"type:text" json:"-"` - ParsedProfile map[string]interface{} `gorm:"-" json:"profile"` + Profile *string `gorm:"type:text" json:"-"` + ParsedProfile map[string]any `gorm:"-" json:"profile"` - Config *string `gorm:"type:text" json:"-"` - ParsedConfig map[string]interface{} `gorm:"-" json:"config"` + Config *string `gorm:"type:text" json:"-"` + ParsedConfig map[string]any `gorm:"-" json:"config"` - Claims *string `gorm:"type:text" json:"-"` - 
ParsedClaims map[string]interface{} `gorm:"-" json:"claims"` + Claims *string `gorm:"type:text" json:"-"` + ParsedClaims map[string]any `gorm:"-" json:"claims"` // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go index fb603202eb..07c2e91bb8 100644 --- a/framework/configstore/tables/virtualkey.go +++ b/framework/configstore/tables/virtualkey.go @@ -182,20 +182,16 @@ func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error { Alias MCPClientName string `json:"mcp_client_name"` // Config file format: MCP client name } - var temp TempMCPConfig if err := json.Unmarshal(data, &temp); err != nil { return err } - // Copy all standard fields *mc = TableVirtualKeyMCPConfig(temp.Alias) - // Capture mcp_client_name for later resolution to MCPClientID if temp.MCPClientName != "" { mc.MCPClientName = temp.MCPClientName } - return nil } @@ -210,10 +206,15 @@ type TableVirtualKey struct { MCPConfigs []TableVirtualKeyMCPConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"` // Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both) - TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"` - CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` - RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` - CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries + TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"` + CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` + RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` + + // Deprecated + // Calendar aligned is not the property of 
virtual key but its property of the budget and ratelimit + // So in the migration we will move this to the budget/ratelimit table + // And this won't be referred + CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries // Relationships Team *TableTeam `gorm:"foreignKey:TeamID" json:"team,omitempty"` diff --git a/framework/logstore/asyncjob.go b/framework/logstore/asyncjob.go index 206c82746f..8923420d2d 100644 --- a/framework/logstore/asyncjob.go +++ b/framework/logstore/asyncjob.go @@ -30,11 +30,11 @@ const ( // AsyncOperation represents a function that can be executed asynchronously. // It returns the response and an optional BifrostError. -type AsyncOperation func(ctx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) +type AsyncOperation func(ctx *schemas.BifrostContext) (any, *schemas.BifrostError) // GovernanceStore is an interface that provides access to the governance store. type GovernanceStore interface { - GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) + GetVirtualKey(ctx context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool) } // AsyncJobExecutor manages async job creation and background execution. 
@@ -66,7 +66,7 @@ func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValu if vkValue == nil { return nil, fmt.Errorf("virtual key is required") } - vk, ok := e.governanceStore.GetVirtualKey(*vkValue) + vk, ok := e.governanceStore.GetVirtualKey(ctx, *vkValue) if !ok { return nil, fmt.Errorf("virtual key not found") } @@ -90,7 +90,7 @@ func (e *AsyncJobExecutor) SubmitJob(bifrostCtx *schemas.BifrostContext, resultT var virtualKeyID *string if virtualKeyValue != nil { - vk, ok := e.governanceStore.GetVirtualKey(*virtualKeyValue) + vk, ok := e.governanceStore.GetVirtualKey(bifrostCtx, *virtualKeyValue) if !ok { return nil, fmt.Errorf("virtual key not found") } @@ -112,7 +112,11 @@ func (e *AsyncJobExecutor) SubmitJob(bifrostCtx *schemas.BifrostContext, resultT return nil, fmt.Errorf("failed to create async job: %w", err) } - go e.executeJob(job.ID, job.ResultTTL, operation, bifrostCtx.GetUserValues()) + var contextValues map[any]any + if bifrostCtx != nil { + contextValues = bifrostCtx.GetUserValues() + } + go e.executeJob(job.ID, job.ResultTTL, operation, contextValues) return job, nil } @@ -135,7 +139,7 @@ func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation Asy now := time.Now().UTC() expiresAt := now.Add(time.Duration(resultTTL) * time.Second) errJSON, _ := sonic.Marshal(&schemas.BifrostError{Error: &schemas.ErrorField{Message: msg}}) - if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{ + if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]any{ "status": schemas.AsyncJobStatusFailed, "status_code": fasthttp.StatusInternalServerError, "error": string(errJSON), @@ -299,8 +303,13 @@ func (c *AsyncJobCleaner) cleanupExpiredJobs(ctx context.Context) { } // getVirtualKeyFromContext extracts the virtual key value from context. -// Returns nil if no VK is present (e.g., direct key mode or no governance). 
+// Returns nil if no VK is present (e.g., direct key mode or no governance),
+// or if the context itself is nil (callers like SubmitJob may be invoked with
+// a nil ctx by background paths that don't carry a VK).
 func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
+	if ctx == nil {
+		return nil
+	}
 	vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
 	if vkValue == "" {
 		return nil
diff --git a/framework/logstore/asyncjob_test.go b/framework/logstore/asyncjob_test.go
index df71d7befe..c01bdef008 100644
--- a/framework/logstore/asyncjob_test.go
+++ b/framework/logstore/asyncjob_test.go
@@ -30,7 +30,7 @@ type testGovernanceStore struct {
 	virtualKeys map[string]*configstoreTables.TableVirtualKey
 }
 
-func (t *testGovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) {
+func (t *testGovernanceStore) GetVirtualKey(_ context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool) {
 	vk, ok := t.virtualKeys[vkValue]
 	return vk, ok
 }
@@ -86,14 +86,10 @@ func waitForJobStatus(t *testing.T, store LogStore, jobID string) *AsyncJob {
 func TestSubmitJob_PropagatesContextValues(t *testing.T) {
 	executor := newTestAsyncExecutor(t)
 
-	// Simulate original request context values
-	contextValues := map[any]any{
-		schemas.BifrostContextKeyVirtualKey:         "sk-bf-test",
-		schemas.BifrostContextKey("x-bf-prom-env"):  "production",
-		schemas.BifrostContextKey("x-bf-eh-custom"): "custom-value",
-	}
-
-	var capturedCtx *schemas.BifrostContext
+	capturedCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
+	capturedCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test")
+	capturedCtx.SetValue(schemas.BifrostContextKey("x-bf-eh-custom"), "custom-value")
+	capturedCtx.SetValue(schemas.BifrostContextKey("x-bf-prom-env"), "production")
 	var done atomic.Bool
 	operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
@@ -161,16 +157,18 @@ func 
TestSubmitJob_EmptyContextValues(t *testing.T) { func TestSubmitJob_AsyncFlagOverridesContextValues(t *testing.T) { executor := newTestAsyncExecutor(t) + inputCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + inputCtx.SetValue(schemas.BifrostIsAsyncRequest, false) + var capturedCtx *schemas.BifrostContext var done atomic.Bool - capturedCtx.SetValue(schemas.BifrostIsAsyncRequest, false) operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { capturedCtx = bgCtx done.Store(true) return map[string]string{"status": "ok"}, nil } - job, err := executor.SubmitJob(capturedCtx, 3600, operation, schemas.ChatCompletionRequest) + job, err := executor.SubmitJob(inputCtx, 3600, operation, schemas.ChatCompletionRequest) require.NoError(t, err) require.NotNil(t, job) @@ -183,8 +181,10 @@ func TestSubmitJob_AsyncFlagOverridesContextValues(t *testing.T) { func TestSubmitJob_OperationFailure_PreservesContext(t *testing.T) { executor := newTestAsyncExecutor(t) + inputCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + inputCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test") + var capturedCtx *schemas.BifrostContext - capturedCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test") var done atomic.Bool statusCode := fasthttp.StatusBadRequest @@ -197,7 +197,7 @@ func TestSubmitJob_OperationFailure_PreservesContext(t *testing.T) { } } - job, err := executor.SubmitJob(capturedCtx, 3600, operation, schemas.ChatCompletionRequest) + job, err := executor.SubmitJob(inputCtx, 3600, operation, schemas.ChatCompletionRequest) require.NoError(t, err) require.NotNil(t, job) diff --git a/framework/logstore/postgres.go b/framework/logstore/postgres.go index df78b1735d..183d466554 100644 --- a/framework/logstore/postgres.go +++ b/framework/logstore/postgres.go @@ -24,6 +24,13 @@ type PostgresConfig struct { } // newPostgresLogStore creates a new Postgres log store. 
+// +// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not +// change result type"): a throwaway pool runs the version check and schema +// migrations and is closed immediately, then a fresh runtime pool is opened +// for query traffic and the async index / matview builders. The runtime +// pool's connections never see pre-migration schema, so their cached +// prepared-plans stay valid for the life of the process. func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (LogStore, error) { if config == nil { return nil, fmt.Errorf("config is required") @@ -48,11 +55,56 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch return nil, fmt.Errorf("postgres ssl mode is required") } dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue()) - db, err := gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - }), &gorm.Config{ - Logger: newGormLogger(logger), - }) + + openPool := func() (*gorm.DB, error) { + return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{ + Logger: newGormLogger(logger), + }) + } + + // closePoolStrict returns the close error so callers can abort startup + // when the throwaway migration pool doesn't tear down cleanly — a half- + // closed pool weakens the guarantee that no cached plans survive DDL. + closePool := func(db *gorm.DB) error { + if db == nil { + return nil + } + sqlDB, err := db.DB() + if err != nil { + return err + } + return sqlDB.Close() + } + + // Throwaway pool for the version gate and schema migrations. Closing it + // before the runtime pool opens guarantees no cached plan survives DDL. 
+ mDb, err := openPool() + if err != nil { + return nil, err + } + + // Postgres version gate: refuse to start below 16 (matviews, partitioning, + // and some JSON operators we rely on depend on 16+). + var pgVersionNum int + if err := mDb.Raw("SELECT current_setting('server_version_num')::int").Scan(&pgVersionNum).Error; err != nil { + _ = closePool(mDb) + return nil, err + } + if pgVersionNum < 160000 { + _ = closePool(mDb) + return nil, fmt.Errorf("postgres version is lower than 16, please upgrade to 16 or higher") + } + + if err := triggerMigrations(ctx, mDb); err != nil { + _ = closePool(mDb) + return nil, err + } + if err := closePool(mDb); err != nil { + return nil, fmt.Errorf("close migration db connection: %w", err) + } + + // Runtime pool. Opens against post-migration schema. + db, err := openPool() if err != nil { return nil, err } @@ -60,6 +112,7 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch // Configure connection pool sqlDB, err := db.DB() if err != nil { + closePool(db) return nil, err } // Set MaxIdleConns (default: 5) @@ -77,25 +130,6 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch sqlDB.SetMaxOpenConns(maxOpenConns) d := &RDBLogStore{db: db, logger: logger} - // Check version of postgres, if is lower than 16, throw fatal error - var pgVersionNum int - if err := db.Raw("SELECT current_setting('server_version_num')::int").Scan(&pgVersionNum).Error; err != nil { - sqlDB.Close() - return nil, err - } - if pgVersionNum < 160000 { - sqlDB.Close() - return nil, fmt.Errorf("postgres version is lower than 16, please upgrade to 16 or higher") - } - - // Run migrations - if err := triggerMigrations(ctx, db); err != nil { - if sqlDB, sqlErr := db.DB(); sqlErr == nil { - sqlDB.Close() - } - return nil, err - } - // Run all index builds sequentially in a single goroutine to prevent // deadlocks from concurrent CREATE INDEX CONCURRENTLY on the same table. 
// Each function is idempotent and acquires its own advisory lock for diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 26bbdd1a44..265d4a8d65 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -398,7 +398,7 @@ func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req // Process virtual key if provided if virtualKeyValue != nil { - virtualKey, ok = p.store.GetVirtualKey(*virtualKeyValue) + virtualKey, ok = p.store.GetVirtualKey(ctx, *virtualKeyValue) if !ok || virtualKey == nil || !virtualKey.IsActive { return nil, nil } @@ -496,7 +496,7 @@ func (p *GovernancePlugin) governLargePayload(ctx *schemas.BifrostContext, req * // Process virtual key if provided var virtualKey *configstoreTables.TableVirtualKey if virtualKeyValue != nil { - vk, ok := p.store.GetVirtualKey(*virtualKeyValue) + vk, ok := p.store.GetVirtualKey(ctx, *virtualKeyValue) if !ok || vk == nil || !vk.IsActive { return nil, nil } @@ -1061,7 +1061,7 @@ func (p *GovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext // Checking if the virtual key is valid or not isVirtualKeyValid := false if evaluationRequest.VirtualKey != "" { - _, exists := p.store.GetVirtualKey(evaluationRequest.VirtualKey) + _, exists := p.store.GetVirtualKey(ctx, evaluationRequest.VirtualKey) if exists { isVirtualKeyValid = true } else { @@ -1095,22 +1095,70 @@ func (p *GovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext // First evaluate model and provider checks (applies even when virtual keys are disabled or not present) result := p.resolver.EvaluateModelAndProviderRequest(ctx, evaluationRequest.Provider, evaluationRequest.Model) - // Check user-level governance (enterprise-only, runs before VK checks) - if result.Decision == DecisionAllow { - result = p.resolver.EvaluateUserRequest(ctx, evaluationRequest.UserID, evaluationRequest) + // The flow for governance checks is: + // VK (identity + VK-level budget/rate-limit) -> 
Customer -> Team -> User + // VK identity runs FIRST so that revoked, provider-disallowed, or model-disallowed + // keys are blocked before any hierarchy state is consulted. Running Customer/Team/ + // User ahead of VK would leak topology: a revoked key attached to an over-budget + // team would return 429 team-budget-exceeded instead of 403 VK-blocked, telling + // an attacker the key's team structure was validated. + + // Resolve the VK once; it feeds both the VK evaluation and hierarchy-ID extraction. + var hierarchyVK *configstoreTables.TableVirtualKey + if evaluationRequest.VirtualKey != "" { + if vk, ok := p.store.GetVirtualKey(ctx, evaluationRequest.VirtualKey); ok && vk != nil { + hierarchyVK = vk + } } - // If model/provider checks passed, evaluate virtual key + // Step 1: Evaluate virtual key (identity + VK-level budget/rate-limit hierarchy). + // Short-circuits with VirtualKeyBlocked / ProviderBlocked / ModelBlocked before + // we touch Customer / Team / User. if result.Decision == DecisionAllow && evaluationRequest.VirtualKey != "" { - if evaluationRequest.UserID != "" { - // User auth present: only use VK for routing/filtering (skip rate limits and budgets) - result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, true) - } else { - // No user auth: full VK governance (routing + limits) - result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, false) + skipVKBudgetLimit := evaluationRequest.UserID != "" + result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, skipVKBudgetLimit) + } + + // Step 2: Customer-level budget (customer attached directly to VK, or via the VK's team). 
+ // Fall back to the loaded relation IDs so VKs populated via joins without FK + // pointer columns still participate in customer-level enforcement. + if result.Decision == DecisionAllow && hierarchyVK != nil { + var customerID string + switch { + case hierarchyVK.CustomerID != nil: + customerID = *hierarchyVK.CustomerID + case hierarchyVK.Customer != nil: + customerID = hierarchyVK.Customer.ID + case hierarchyVK.Team != nil && hierarchyVK.Team.CustomerID != nil: + customerID = *hierarchyVK.Team.CustomerID + case hierarchyVK.Team != nil && hierarchyVK.Team.Customer != nil: + customerID = hierarchyVK.Team.Customer.ID + } + if customerID != "" { + result = p.resolver.EvaluateCustomerRequest(ctx, customerID, evaluationRequest) } } + // Step 3: Team-level budget. Fall back to vk.Team.ID when the FK pointer is nil + // but the relation is populated. + if result.Decision == DecisionAllow && hierarchyVK != nil { + var teamID string + switch { + case hierarchyVK.TeamID != nil: + teamID = *hierarchyVK.TeamID + case hierarchyVK.Team != nil: + teamID = hierarchyVK.Team.ID + } + if teamID != "" { + result = p.resolver.EvaluateTeamRequest(ctx, teamID, evaluationRequest) + } + } + + // Step 4: User-level governance (enterprise-only). + if result.Decision == DecisionAllow { + result = p.resolver.EvaluateUserRequest(ctx, evaluationRequest.UserID, evaluationRequest) + } + // Check the actual MCP tools injected into the request against the VK MCPConfigs. // BifrostContextKeyMCPAddedTools is populated by AddToolsToRequest (which runs before // PreLLMHook), so it contains the real expanded tool names (e.g. 
"youtube-search") rather @@ -1315,7 +1363,7 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche if requestType == schemas.ListModelsRequest && result != nil && result.ListModelsResponse != nil && virtualKey != "" { // filter models which are not supported on this virtual key - result.ListModelsResponse.Data = p.filterModelsForVirtualKey(result.ListModelsResponse.Data, virtualKey) + result.ListModelsResponse.Data = p.filterModelsForVirtualKey(ctx, result.ListModelsResponse.Data, virtualKey) } isFinalChunk := bifrost.IsFinalChunk(ctx) @@ -1389,7 +1437,7 @@ func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas. // Blind single-tool check: validate the specific tool being executed against VK MCPConfigs. // This runs independently of EvaluateGovernanceRequest to enforce execution-time allow-list. if virtualKeyValue != "" { - vk, ok := p.store.GetVirtualKey(virtualKeyValue) + vk, ok := p.store.GetVirtualKey(ctx, virtualKeyValue) if !ok || vk == nil || !vk.IsActive { // VK became invalid after initial check - fail closed for security ctx.SetValue(governanceRejectedContextKey, true) diff --git a/plugins/governance/modelprovidergovernance_test.go b/plugins/governance/modelprovidergovernance_test.go index 80a6c66ffa..f0fb5b1d2e 100644 --- a/plugins/governance/modelprovidergovernance_test.go +++ b/plugins/governance/modelprovidergovernance_test.go @@ -21,7 +21,7 @@ func TestStore_CheckProviderBudget_NoConfig(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err, "Should allow when no provider config exists") } @@ -33,7 +33,7 @@ func TestStore_CheckProviderBudget_NoBudget(t *testing.T) { }, 
nil) require.NoError(t, err) - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err, "Should allow when provider has no budget") } @@ -47,7 +47,7 @@ func TestStore_CheckProviderBudget_WithinLimit(t *testing.T) { }, nil) require.NoError(t, err) - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err, "Should allow when budget is within limit") } @@ -61,7 +61,7 @@ func TestStore_CheckProviderBudget_Exceeded(t *testing.T) { }, nil) require.NoError(t, err) - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.Error(t, err, "Should reject when budget is exceeded") assert.Contains(t, err.Error(), "budget exceeded") } @@ -78,7 +78,7 @@ func TestStore_CheckProviderBudget_WithBaseline(t *testing.T) { // With baseline that would exceed limit baselines := map[string]float64{"budget1": 15.0} - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, baselines) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, baselines) assert.Error(t, err, "Should reject when current usage + baseline exceeds limit") assert.Contains(t, err.Error(), "budget exceeded") } @@ -92,7 +92,7 @@ func TestStore_CheckProviderRateLimit_NoConfig(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - err, decision := store.CheckProviderRateLimit(context.Background(), 
&EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.NoError(t, err, "Should allow when no provider config exists") assert.Equal(t, DecisionAllow, decision) } @@ -105,7 +105,7 @@ func TestStore_CheckProviderRateLimit_NoRateLimit(t *testing.T) { }, nil) require.NoError(t, err) - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.NoError(t, err, "Should allow when provider has no rate limit") assert.Equal(t, DecisionAllow, decision) } @@ -120,7 +120,7 @@ func TestStore_CheckProviderRateLimit_TokenLimitExceeded(t *testing.T) { }, nil) require.NoError(t, err) - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.Error(t, err, "Should reject when provider token limit is exceeded") assert.Equal(t, DecisionTokenLimited, decision) assert.Contains(t, err.Error(), "token limit exceeded") @@ -136,7 +136,7 @@ func TestStore_CheckProviderRateLimit_RequestLimitExceeded(t *testing.T) { }, nil) require.NoError(t, err) - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.Error(t, err, "Should reject when provider request limit is exceeded") assert.Equal(t, DecisionRequestLimited, decision) assert.Contains(t, err.Error(), "request limit exceeded") @@ -152,7 +152,7 @@ func TestStore_CheckProviderRateLimit_BothLimitsExceeded(t *testing.T) { }, 
nil) require.NoError(t, err) - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.Error(t, err, "Should reject when both provider token and request limits are exceeded") assert.Equal(t, DecisionRateLimited, decision) // General rate limited when both are exceeded assert.Contains(t, err.Error(), "rate limit") @@ -168,7 +168,7 @@ func TestStore_CheckProviderRateLimit_WithinLimits(t *testing.T) { }, nil) require.NoError(t, err) - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.NoError(t, err, "Should allow when provider rate limits are within limits") assert.Equal(t, DecisionAllow, decision) } @@ -183,7 +183,7 @@ func TestStore_CheckModelBudget_NoConfig(t *testing.T) { require.NoError(t, err) provider := schemas.OpenAI - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.NoError(t, err, "Should allow when no model config exists") } @@ -198,7 +198,7 @@ func TestStore_CheckModelBudget_ModelOnly_WithinLimit(t *testing.T) { require.NoError(t, err) provider := schemas.OpenAI - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.NoError(t, err, "Should allow when model budget is within limit") } @@ -213,7 +213,7 @@ func TestStore_CheckModelBudget_ModelOnly_Exceeded(t *testing.T) { require.NoError(t, err) 
provider := schemas.OpenAI - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.Error(t, err, "Should reject when model budget is exceeded") assert.Contains(t, err.Error(), "budget exceeded") } @@ -230,7 +230,7 @@ func TestStore_CheckModelBudget_ModelWithProvider_WithinLimit(t *testing.T) { require.NoError(t, err) provider := schemas.OpenAI - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.NoError(t, err, "Should allow when model+provider budget is within limit") } @@ -246,7 +246,7 @@ func TestStore_CheckModelBudget_ModelWithProvider_Exceeded(t *testing.T) { require.NoError(t, err) provider := schemas.OpenAI - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.Error(t, err, "Should reject when model+provider budget is exceeded") assert.Contains(t, err.Error(), "budget exceeded") } @@ -267,7 +267,7 @@ func TestStore_CheckModelBudget_BothModelAndModelProvider_ChecksBoth(t *testing. 
require.NoError(t, err) provider := schemas.OpenAI - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.Error(t, err, "Should reject when model-only budget is exceeded, even if model+provider budget is OK") assert.Contains(t, err.Error(), "budget exceeded") } @@ -286,7 +286,7 @@ func TestStore_CheckModelBudget_ProviderSpecific_DifferentProvider_Passes(t *tes // Request with Azure (different provider) for same model should pass provider := schemas.Azure - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: provider}, nil) assert.NoError(t, err, "Should allow when model config is provider-specific and different provider is used") } @@ -300,7 +300,7 @@ func TestStore_CheckModelRateLimit_NoConfig(t *testing.T) { require.NoError(t, err) provider := schemas.OpenAI - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when no model config exists") assert.Equal(t, DecisionAllow, decision) } @@ -316,7 +316,7 @@ func TestStore_CheckModelRateLimit_ModelOnly_TokenLimitExceeded(t *testing.T) { require.NoError(t, err) provider := schemas.OpenAI - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model token limit is exceeded") assert.Equal(t, 
DecisionTokenLimited, decision) assert.Contains(t, err.Error(), "token limit exceeded") @@ -333,7 +333,7 @@ func TestStore_CheckModelRateLimit_ModelOnly_RequestLimitExceeded(t *testing.T) require.NoError(t, err) provider := schemas.OpenAI - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model request limit is exceeded") assert.Equal(t, DecisionRequestLimited, decision) assert.Contains(t, err.Error(), "request limit exceeded") @@ -351,7 +351,7 @@ func TestStore_CheckModelRateLimit_ModelWithProvider_WithinLimits(t *testing.T) require.NoError(t, err) provider := schemas.OpenAI - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when model+provider rate limits are within limits") assert.Equal(t, DecisionAllow, decision) } @@ -372,7 +372,7 @@ func TestStore_CheckModelRateLimit_BothModelAndModelProvider_ChecksBoth(t *testi require.NoError(t, err) provider := schemas.OpenAI - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model-only rate limit is exceeded") assert.Equal(t, DecisionTokenLimited, decision) assert.Contains(t, err.Error(), "token limit exceeded") @@ -394,7 +394,7 @@ func TestStore_CheckModelRateLimit_BothModelAndModelProvider_ChecksBoth_RequestL require.NoError(t, err) provider := schemas.OpenAI - err, 
decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model-only rate limit (request limit) is exceeded") assert.Equal(t, DecisionRequestLimited, decision) assert.Contains(t, err.Error(), "request limit exceeded") @@ -414,7 +414,7 @@ func TestStore_CheckModelRateLimit_ProviderSpecific_DifferentProvider_Passes(t * // Request with Azure (different provider) for same model should pass provider := schemas.Azure - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when model config is provider-specific and different provider is used") assert.Equal(t, DecisionAllow, decision) } @@ -433,7 +433,7 @@ func TestStore_CheckModelRateLimit_ProviderSpecific_DifferentProvider_Passes_Req // Request with Azure (different provider) for same model should pass provider := schemas.Azure - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when model config is provider-specific and different provider is used (request limit)") assert.Equal(t, DecisionAllow, decision) } @@ -465,7 +465,7 @@ func TestStore_UpdateProviderBudgetUsage_UpdatesUsage(t *testing.T) { assert.NoError(t, err, "Should successfully update provider budget usage") // Verify usage was updated - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: 
schemas.OpenAI}, nil) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err, "Should still be within limit after first update") // Update again to exceed @@ -473,7 +473,7 @@ func TestStore_UpdateProviderBudgetUsage_UpdatesUsage(t *testing.T) { assert.NoError(t, err, "Should successfully update provider budget usage even when exceeding") // Now should be exceeded - err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckProviderBudget(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.Error(t, err, "Should be exceeded after second update") assert.Contains(t, err.Error(), "budget exceeded") } @@ -505,7 +505,7 @@ func TestStore_UpdateProviderRateLimitUsage_UpdatesTokens(t *testing.T) { assert.NoError(t, err, "Should successfully update provider token usage") // Check that tokens were updated but requests were not - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.NoError(t, err, "Should still be within token limit") assert.Equal(t, DecisionAllow, decision) @@ -514,7 +514,7 @@ func TestStore_UpdateProviderRateLimitUsage_UpdatesTokens(t *testing.T) { assert.NoError(t, err, "Should successfully update provider token usage even when exceeding") // Now should be exceeded - err, decision = store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err = store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.Error(t, err, "Should reject when provider token limit is exceeded after update") assert.Equal(t, DecisionTokenLimited, decision) assert.Contains(t, err.Error(), "token limit 
exceeded") @@ -537,7 +537,7 @@ func TestStore_UpdateProviderRateLimitUsage_UpdatesRequests(t *testing.T) { } // Should still be within limit - err, decision := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err := store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.NoError(t, err, "Should allow when provider request limit is within limit") assert.Equal(t, DecisionAllow, decision) @@ -548,7 +548,7 @@ func TestStore_UpdateProviderRateLimitUsage_UpdatesRequests(t *testing.T) { } // Now should be exceeded - err, decision = store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) + decision, err = store.CheckProviderRateLimit(context.Background(), &EvaluationRequest{Provider: schemas.OpenAI}, nil, nil) assert.Error(t, err, "Should reject when provider request limit is exceeded after update") assert.Equal(t, DecisionRequestLimited, decision) assert.Contains(t, err.Error(), "request limit exceeded") @@ -583,7 +583,7 @@ func TestStore_UpdateModelBudgetUsage_ModelOnly_UpdatesUsage(t *testing.T) { assert.NoError(t, err, "Should successfully update model budget usage") // Verify usage was updated - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.NoError(t, err, "Should still be within limit after first update") // Update again to exceed @@ -591,7 +591,7 @@ func TestStore_UpdateModelBudgetUsage_ModelOnly_UpdatesUsage(t *testing.T) { assert.NoError(t, err, "Should successfully update model budget usage even when exceeding") // Now should be exceeded - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = 
store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.Error(t, err, "Should be exceeded after second update") assert.Contains(t, err.Error(), "budget exceeded") } @@ -617,7 +617,7 @@ func TestStore_UpdateModelBudgetUsage_ModelWithProvider_UpdatesBoth(t *testing.T // Both budgets should be updated // Check model-only budget - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.NoError(t, err, "Should still be within limit") // Update to exceed model-only budget @@ -625,7 +625,7 @@ func TestStore_UpdateModelBudgetUsage_ModelWithProvider_UpdatesBoth(t *testing.T assert.NoError(t, err, "Should successfully update model budget usage even when exceeding") // Now model-only budget should be exceeded - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil) assert.Error(t, err, "Should be exceeded when model-only budget is exceeded") assert.Contains(t, err.Error(), "budget exceeded") } @@ -659,7 +659,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelOnly_UpdatesUsage(t *testing.T) { assert.NoError(t, err, "Should successfully update model token usage") // Should still be within limit - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when model token limit is within limit") assert.Equal(t, DecisionAllow, decision) @@ -668,7 +668,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelOnly_UpdatesUsage(t *testing.T) { 
assert.NoError(t, err, "Should successfully update model token usage even when exceeding") // Now should be exceeded - err, decision = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model token limit is exceeded after update") assert.Equal(t, DecisionTokenLimited, decision) assert.Contains(t, err.Error(), "token limit exceeded") @@ -694,7 +694,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelWithProvider_UpdatesUsage(t *testi assert.NoError(t, err, "Should successfully update both model-only and model+provider token usage") // Should still be within limit - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when both rate limits are within limit") assert.Equal(t, DecisionAllow, decision) @@ -703,7 +703,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelWithProvider_UpdatesUsage(t *testi assert.NoError(t, err, "Should successfully update model token usage even when exceeding") // Now should be exceeded (model-only rate limit exceeded) - err, decision = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model-only token limit is exceeded after update") assert.Equal(t, DecisionTokenLimited, decision) assert.Contains(t, err.Error(), "token limit exceeded") @@ -727,7 +727,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelOnly_UpdatesUsage_RequestLimit(t * 
} // Should still be within limit - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when model request limit is within limit") assert.Equal(t, DecisionAllow, decision) @@ -738,7 +738,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelOnly_UpdatesUsage_RequestLimit(t * } // Now should be exceeded - err, decision = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when model request limit is exceeded after update") assert.Equal(t, DecisionRequestLimited, decision) assert.Contains(t, err.Error(), "request limit exceeded") @@ -767,7 +767,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelWithProvider_UpdatesUsage_RequestL } // Should still be within limit - err, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.NoError(t, err, "Should allow when both rate limits are within limit") assert.Equal(t, DecisionAllow, decision) @@ -778,7 +778,7 @@ func TestStore_UpdateModelRateLimitUsage_ModelWithProvider_UpdatesUsage_RequestL } // Now should be exceeded (model-only rate limit exceeded) - err, decision = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) + decision, err = store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "gpt-4", Provider: provider}, nil, nil) assert.Error(t, err, "Should reject when 
model-only request limit is exceeded after update") assert.Equal(t, DecisionRequestLimited, decision) assert.Contains(t, err.Error(), "request limit exceeded") @@ -1958,7 +1958,7 @@ func TestStore_CheckModelBudget_CrossProviderModelMatch(t *testing.T) { require.NoError(t, err) // Request with provider-prefixed model name should match the "gpt-4o" config - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) assert.Error(t, err, "Should reject: openai/gpt-4o should match model-only config for gpt-4o") assert.Contains(t, err.Error(), "budget exceeded") } @@ -1977,7 +1977,7 @@ func TestStore_CheckModelBudget_CrossProviderModelMatch_WithinLimit(t *testing.T }, mc) require.NoError(t, err) - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) assert.NoError(t, err, "Should allow: budget is within limit") } @@ -1995,7 +1995,7 @@ func TestStore_CheckModelRateLimit_CrossProviderModelMatch(t *testing.T) { }, mc) require.NoError(t, err) - errResult, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil, nil) + decision, errResult := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil, nil) assert.Error(t, errResult, "Should reject: openai/gpt-4o should match model-only rate limit for gpt-4o") assert.Contains(t, errResult.Error(), "token limit exceeded") assert.NotEqual(t, DecisionAllow, decision) @@ -2024,7 +2024,7 @@ func TestStore_UpdateModelBudgetUsage_CrossProviderModelMatch(t *testing.T) { 
assert.NoError(t, err) // Budget should now be exceeded - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) assert.Error(t, err, "Budget should be exceeded after usage updates via cross-provider match") assert.Contains(t, err.Error(), "budget exceeded") } @@ -2048,7 +2048,7 @@ func TestStore_UpdateModelRateLimitUsage_CrossProviderModelMatch(t *testing.T) { assert.NoError(t, err, "Should successfully update rate limit via cross-provider match") // Rate limit should now be exceeded - errResult, decision := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil, nil) + decision, errResult := store.CheckModelRateLimit(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil, nil) assert.Error(t, errResult, "Token limit should be exceeded after usage update via cross-provider match") assert.Contains(t, errResult.Error(), "token limit exceeded") assert.NotEqual(t, DecisionAllow, decision) @@ -2070,11 +2070,11 @@ func TestStore_CheckModelBudget_ModelWithProvider_ExactMatchOnly(t *testing.T) { require.NoError(t, err) // Request with the exact matching model+provider should be rejected (budget exceeded) - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: schemas.OpenAI}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: schemas.OpenAI}, nil) assert.Error(t, err, "Exact model+provider match should apply budget") // Request with a different provider should NOT match the provider-specific config - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: schemas.OpenRouter}, nil) + _, err = 
store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: schemas.OpenRouter}, nil) assert.NoError(t, err, "Different provider should not match provider-specific config") } @@ -2093,10 +2093,10 @@ func TestStore_CheckModelBudget_NoCatalog_NoMatch(t *testing.T) { require.NoError(t, err) // Without catalog, "openai/gpt-4o" won't match "gpt-4o" config - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "openai/gpt-4o", Provider: schemas.OpenRouter}, nil) assert.NoError(t, err, "Without model catalog, cross-provider matching should not happen") // Direct match should still work - err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: schemas.OpenAI}, nil) + _, err = store.CheckModelBudget(context.Background(), &EvaluationRequest{Model: "gpt-4o", Provider: schemas.OpenAI}, nil) assert.Error(t, err, "Direct match should still work without catalog") } diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 91e154c0e4..84a13103bd 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -88,34 +88,34 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon } // 1. Check provider-level rate limits FIRST (before model-level checks) if provider != "" { - if err, decision := r.store.CheckProviderRateLimit(ctx, request, nil, nil); err != nil { + if decision, err := r.store.CheckProviderRateLimit(ctx, request, nil, nil); err != nil || isRateLimitViolation(decision) { return &EvaluationResult{ Decision: decision, - Reason: fmt.Sprintf("Provider-level rate limit check failed: %s", err.Error()), + Reason: fmt.Sprintf("Provider-level rate limit check failed: %s", reasonFromErr(err, decision)), } } // 2. 
Check provider-level budgets FIRST (before model-level checks) - if err := r.store.CheckProviderBudget(ctx, request, nil); err != nil { + if decision, err := r.store.CheckProviderBudget(ctx, request, nil); err != nil || isBudgetViolation(decision) { return &EvaluationResult{ - Decision: DecisionBudgetExceeded, - Reason: fmt.Sprintf("Provider-level budget exceeded: %s", err.Error()), + Decision: decision, + Reason: fmt.Sprintf("Provider-level budget exceeded: %s", reasonFromErr(err, decision)), } } } // 3. Check model-level rate limits (after provider-level checks) if model != "" { - if err, decision := r.store.CheckModelRateLimit(ctx, request, nil, nil); err != nil { + if decision, err := r.store.CheckModelRateLimit(ctx, request, nil, nil); err != nil || isRateLimitViolation(decision) { return &EvaluationResult{ Decision: decision, - Reason: fmt.Sprintf("Model-level rate limit check failed: %s", err.Error()), + Reason: fmt.Sprintf("Model-level rate limit check failed: %s", reasonFromErr(err, decision)), } } // 4. 
Check model-level budgets (after provider-level checks) - if err := r.store.CheckModelBudget(ctx, request, nil); err != nil { + if decision, err := r.store.CheckModelBudget(ctx, request, nil); err != nil || isBudgetViolation(decision) { return &EvaluationResult{ - Decision: DecisionBudgetExceeded, - Reason: fmt.Sprintf("Model-level budget exceeded: %s", err.Error()), + Decision: decision, + Reason: fmt.Sprintf("Model-level budget exceeded: %s", reasonFromErr(err, decision)), } } } @@ -126,6 +126,67 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon } } +func (r *BudgetResolver) EvaluateCustomerRequest(ctx *schemas.BifrostContext, customerID string, request *EvaluationRequest) *EvaluationResult { + // Skip if no customerID + if customerID == "" { + return &EvaluationResult{ + Decision: DecisionAllow, + Reason: "No customer ID provided, skipping customer-level checks", + } + } + // Check customer-level rate limits + if decision, err := r.store.CheckCustomerRateLimit(ctx, customerID, request, nil, nil); err != nil || isRateLimitViolation(decision) { + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Customer-level rate limit exceeded: %s", reasonFromErr(err, decision)), + } + } + + // Check customer-level budget + if decision, err := r.store.CheckCustomerBudget(ctx, customerID, request, nil); err != nil || isBudgetViolation(decision) { + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Customer-level budget exceeded: %s", reasonFromErr(err, decision)), + } + } + + return &EvaluationResult{ + Decision: DecisionAllow, + Reason: "Customer-level checks passed", + } +} + +func (r *BudgetResolver) EvaluateTeamRequest(ctx *schemas.BifrostContext, teamID string, request *EvaluationRequest) *EvaluationResult { + // Skip if no teamID + if teamID == "" { + return &EvaluationResult{ + Decision: DecisionAllow, + Reason: "No team ID provided, skipping team-level checks", + } + } + // Check team-level rate 
limits + if decision, err := r.store.CheckTeamRateLimit(ctx, teamID, request, nil, nil); err != nil || isRateLimitViolation(decision) { + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Team-level rate limit exceeded: %s", reasonFromErr(err, decision)), + } + } + + // Check team-level budget + if decision, err := r.store.CheckTeamBudget(ctx, teamID, request, nil); err != nil || isBudgetViolation(decision) { + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Team-level budget exceeded: %s", reasonFromErr(err, decision)), + } + } + + return &EvaluationResult{ + Decision: DecisionAllow, + Reason: "Team-level checks passed", + } + +} + // EvaluateUserRequest evaluates user-level rate limits and budgets (enterprise-only) // This runs after provider/model checks but before VK checks // Returns DecisionAllow if userID is empty or user has no governance configured @@ -139,18 +200,18 @@ func (r *BudgetResolver) EvaluateUserRequest(ctx *schemas.BifrostContext, userID } // Check user-level rate limits - if err, decision := r.store.CheckUserRateLimit(ctx, userID, request, nil, nil); err != nil { + if decision, err := r.store.CheckUserRateLimit(ctx, userID, request, nil, nil); err != nil || isRateLimitViolation(decision) { return &EvaluationResult{ Decision: decision, - Reason: fmt.Sprintf("User-level rate limit exceeded: %s", err.Error()), + Reason: fmt.Sprintf("User-level rate limit exceeded: %s", reasonFromErr(err, decision)), } } // Check user-level budget - if err := r.store.CheckUserBudget(ctx, userID, request, nil); err != nil { + if decision, err := r.store.CheckUserBudget(ctx, userID, request, nil); err != nil || isBudgetViolation(decision) { return &EvaluationResult{ - Decision: DecisionBudgetExceeded, - Reason: fmt.Sprintf("User-level budget exceeded: %s", err.Error()), + Decision: decision, + Reason: fmt.Sprintf("User-level budget exceeded: %s", reasonFromErr(err, decision)), } } @@ -175,7 +236,7 @@ func (r 
*BudgetResolver) isModelRequired(requestType schemas.RequestType) bool { // skipRateLimitsAndBudgets evaluates to true when we want to skip rate limits and budgets. This is used when user auth is present (user governance handles limits). func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestType schemas.RequestType, skipRateLimitsAndBudgets bool) *EvaluationResult { // 1. Validate virtual key exists and is active - vk, exists := r.store.GetVirtualKey(virtualKeyValue) + vk, exists := r.store.GetVirtualKey(ctx, virtualKeyValue) if !exists { return &EvaluationResult{ Decision: DecisionVirtualKeyNotFound, @@ -308,7 +369,7 @@ func (r *BudgetResolver) isProviderAllowed(vk *configstoreTables.TableVirtualKey // checkRateLimitHierarchy checks provider-level rate limits first, then VK rate limits using flexible approach func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest) *EvaluationResult { - if decision, err := r.store.CheckRateLimit(ctx, vk, request, nil, nil); err != nil { + if decision, err := r.store.CheckVirtualKeyRateLimit(ctx, vk, request, nil, nil); err != nil || isRateLimitViolation(decision) { // Check provider-level first (matching check order), then VK-level var rateLimitInfo *configstoreTables.TableRateLimit for _, pc := range vk.ProviderConfigs { @@ -322,7 +383,7 @@ func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *config } return &EvaluationResult{ Decision: decision, - Reason: fmt.Sprintf("Rate limit check failed: %s", err.Error()), + Reason: fmt.Sprintf("Rate limit check failed: %s", reasonFromErr(err, decision)), VirtualKey: vk, RateLimitInfo: rateLimitInfo, } @@ -334,16 +395,14 @@ func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *config // checkBudgetHierarchy checks the budget hierarchy atomically (VK → Team → Customer) 
func (r *BudgetResolver) checkBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest) *EvaluationResult { // Use atomic budget checking to prevent race conditions - if err := r.store.CheckBudget(ctx, vk, request, nil); err != nil { - r.logger.Debug(fmt.Sprintf("Atomic budget exceeded for VK %s: %s", vk.ID, err.Error())) - + if decision, err := r.store.CheckVirtualKeyBudget(ctx, vk, request, nil); err != nil || isBudgetViolation(decision) { + r.logger.Debug(fmt.Sprintf("Atomic budget exceeded for VK %s: %s", vk.ID, reasonFromErr(err, decision))) return &EvaluationResult{ - Decision: DecisionBudgetExceeded, - Reason: fmt.Sprintf("Budget exceeded: %s", err.Error()), + Decision: decision, + Reason: fmt.Sprintf("Budget exceeded: %s", reasonFromErr(err, decision)), VirtualKey: vk, } } - return nil // No budget violations } @@ -354,7 +413,7 @@ func (r *BudgetResolver) isProviderBudgetViolated(ctx context.Context, vk *confi request := &EvaluationRequest{Provider: schemas.ModelProvider(config.Provider)} // 1. Check global provider-level budget first - if err := r.store.CheckProviderBudget(ctx, request, nil); err != nil { + if _, err := r.store.CheckProviderBudget(ctx, request, nil); err != nil { r.logger.Debug(fmt.Sprintf("Global provider budget exceeded for provider %s: %s", config.Provider, err.Error())) return true } @@ -363,7 +422,7 @@ func (r *BudgetResolver) isProviderBudgetViolated(ctx context.Context, vk *confi if len(config.Budgets) == 0 { return false } - if err := r.store.CheckBudget(ctx, vk, request, nil); err != nil { + if _, err := r.store.CheckVirtualKeyBudget(ctx, vk, request, nil); err != nil { r.logger.Debug(fmt.Sprintf("VK provider config budget exceeded for VK %s: %s", vk.ID, err.Error())) return true } @@ -375,7 +434,7 @@ func (r *BudgetResolver) isProviderRateLimitViolated(ctx context.Context, vk *co request := &EvaluationRequest{Provider: schemas.ModelProvider(config.Provider)} // 1. 
Check global provider-level rate limit first - if err, decision := r.store.CheckProviderRateLimit(ctx, request, nil, nil); err != nil || isRateLimitViolation(decision) { + if decision, err := r.store.CheckProviderRateLimit(ctx, request, nil, nil); err != nil || isRateLimitViolation(decision) { r.logger.Debug(fmt.Sprintf("Global provider rate limit exceeded for provider %s", config.Provider)) return true } @@ -384,7 +443,7 @@ func (r *BudgetResolver) isProviderRateLimitViolated(ctx context.Context, vk *co if config.RateLimit == nil { return false } - decision, err := r.store.CheckRateLimit(ctx, vk, request, nil, nil) + decision, err := r.store.CheckVirtualKeyRateLimit(ctx, vk, request, nil, nil) if err != nil || isRateLimitViolation(decision) { r.logger.Debug(fmt.Sprintf("VK provider config rate limit exceeded for VK %s, provider %s", vk.ID, config.Provider)) return true @@ -396,3 +455,18 @@ func isRateLimitViolation(decision Decision) bool { return decision == DecisionRateLimited || decision == DecisionTokenLimited || decision == DecisionRequestLimited } + +// isBudgetViolation returns true if the decision indicates a budget violation. +func isBudgetViolation(decision Decision) bool { + return decision == DecisionBudgetExceeded +} + +// reasonFromErr yields a nil-safe reason string. When the store returns a +// non-allow decision without an accompanying error, err.Error() would panic — +// fall back to a generic phrase that still names the decision.
+func reasonFromErr(err error, decision Decision) string { + if err != nil { + return err.Error() + } + return fmt.Sprintf("policy violation (%s)", decision) +} diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go index d5626df0ee..718e2632a6 100644 --- a/plugins/governance/resolver_test.go +++ b/plugins/governance/resolver_test.go @@ -289,7 +289,7 @@ func TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy(t *testing.T) // Test: VK budget exceeds should fail // Get the governance data to update the budget directly - governanceData := store.GetGovernanceData() + governanceData := store.GetGovernanceData(context.Background()) vkBudgetToUpdate := governanceData.Budgets["vk-budget"] if vkBudgetToUpdate != nil { vkBudgetToUpdate.CurrentUsage = 100.0 diff --git a/plugins/governance/routing.go b/plugins/governance/routing.go index 8e20d5ac48..4df039f12d 100644 --- a/plugins/governance/routing.go +++ b/plugins/governance/routing.go @@ -138,7 +138,7 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi for _, scope := range scopeChain { scopeID := scope.ScopeID - rules := re.store.GetScopedRoutingRules(scope.ScopeName, scopeID) + rules := re.store.GetScopedRoutingRules(ctx, scope.ScopeName, scopeID) re.logger.Debug("[RoutingEngine] Evaluating scope=%s, scopeID=%s, ruleCount=%d", scope.ScopeName, scopeID, len(rules)) if len(rules) == 0 { @@ -154,7 +154,7 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi for _, rule := range rules { re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression) - program, err := re.store.GetRoutingProgram(rule) + program, err := re.store.GetRoutingProgram(ctx, rule) if err != nil { re.logger.Warn("[RoutingEngine] Failed to compile rule %s: %v", rule.Name, err) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: compile error: %v", rule.Name, err)) diff 
--git a/plugins/governance/routing_test.go b/plugins/governance/routing_test.go index 21e8f7469d..b1816d8dd4 100644 --- a/plugins/governance/routing_test.go +++ b/plugins/governance/routing_test.go @@ -246,7 +246,7 @@ func TestEvaluateRoutingRules_GlobalRuleMatches(t *testing.T) { } // Store the rule - require.NoError(t, store.UpdateRoutingRuleInMemory(rule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule)) // Create routing context ctx := &RoutingContext{ @@ -312,7 +312,7 @@ func TestEvaluateRoutingRules_MultiTargetDeterministicWithPinnedKey(t *testing.T Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(rule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule)) routingCtx := &RoutingContext{ Provider: schemas.OpenAI, @@ -364,7 +364,7 @@ func TestEvaluateRoutingRules_ScopePrecedence(t *testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(globalRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), globalRule)) // Create VK-specific rule (should take precedence) vkRule := &configstoreTables.TableRoutingRule{ @@ -379,7 +379,7 @@ func TestEvaluateRoutingRules_ScopePrecedence(t *testing.T) { ScopeID: bifrost.Ptr("vk-123"), Priority: 10, } - require.NoError(t, store.UpdateRoutingRuleInMemory(vkRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), vkRule)) // Create routing context with VirtualKey vk := &configstoreTables.TableVirtualKey{ @@ -428,7 +428,7 @@ func TestEvaluateRoutingRules_PriorityOrdering(t *testing.T) { Scope: "global", Priority: 10, } - require.NoError(t, store.UpdateRoutingRuleInMemory(rule1)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule1)) // High precedence rule (evaluated first): lower priority number rule2 := &configstoreTables.TableRoutingRule{ @@ -442,7 +442,7 @@ func TestEvaluateRoutingRules_PriorityOrdering(t 
*testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(rule2)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule2)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -477,7 +477,7 @@ func TestResolveRoutingWithFallback_RuleMatches(t *testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(rule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -543,7 +543,7 @@ func TestEvaluateRoutingRules_DisabledRulesIgnored(t *testing.T) { Scope: "global", Priority: 10, } - require.NoError(t, store.UpdateRoutingRuleInMemory(disabledRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), disabledRule)) // Create enabled rule enabledRule := &configstoreTables.TableRoutingRule{ @@ -557,7 +557,7 @@ func TestEvaluateRoutingRules_DisabledRulesIgnored(t *testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(enabledRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), enabledRule)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -594,7 +594,7 @@ func TestEvaluateRoutingRules_ComplexExpression(t *testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(rule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule)) // Test with matching headers ctx := &RoutingContext{ @@ -638,7 +638,7 @@ func TestEvaluateRoutingRules_NilVirtualKey(t *testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, store.UpdateRoutingRuleInMemory(rule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -675,7 +675,7 @@ func TestEvaluateRoutingRules_MissingHeaderGracefully(t *testing.T) { Scope: "global", Priority: 0, } - require.NoError(t, 
store.UpdateRoutingRuleInMemory(rule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), rule)) // Create context WITHOUT the header ctx := &RoutingContext{ @@ -721,7 +721,7 @@ func TestEvaluateRoutingRules_ChainRuleReEvaluation(t *testing.T) { Priority: 0, ChainRule: true, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleA)) // Rule B: matches gpt-4-turbo → routes to azure/gpt-4, terminal (chain_rule=false). ruleB := &configstoreTables.TableRoutingRule{ @@ -736,7 +736,7 @@ func TestEvaluateRoutingRules_ChainRuleReEvaluation(t *testing.T) { Priority: 1, ChainRule: false, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleB)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleB)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -779,7 +779,7 @@ func TestEvaluateRoutingRules_TerminalRuleStopsChain(t *testing.T) { Priority: 0, ChainRule: false, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleA)) // Rule B: would match gpt-4-turbo, but should never be reached because Rule A is terminal. 
ruleB := &configstoreTables.TableRoutingRule{ @@ -794,7 +794,7 @@ func TestEvaluateRoutingRules_TerminalRuleStopsChain(t *testing.T) { Priority: 1, ChainRule: false, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleB)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleB)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -836,7 +836,7 @@ func TestEvaluateRoutingRules_ConvergenceStopsChain(t *testing.T) { Priority: 0, ChainRule: true, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleA)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -879,7 +879,7 @@ func TestEvaluateRoutingRules_MaxDepthCutoff(t *testing.T) { Priority: 0, ChainRule: true, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleA)) // Rule B: gpt-4-turbo → azure/gpt-4, chain continues (would proceed to step 2 if depth allowed). ruleB := &configstoreTables.TableRoutingRule{ @@ -894,7 +894,7 @@ func TestEvaluateRoutingRules_MaxDepthCutoff(t *testing.T) { Priority: 1, ChainRule: true, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleB)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleB)) // Rule C: gpt-4 → anthropic/claude-3, would match at step 2 but max depth is 2. 
ruleC := &configstoreTables.TableRoutingRule{ @@ -909,7 +909,7 @@ func TestEvaluateRoutingRules_MaxDepthCutoff(t *testing.T) { Priority: 2, ChainRule: false, } - require.NoError(t, store.UpdateRoutingRuleInMemory(ruleC)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleC)) ctx := &RoutingContext{ Provider: schemas.OpenAI, @@ -948,12 +948,12 @@ func TestCompileAndCacheProgram_ValidExpression_Routing(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) // Verify caching works - second call should return cached program - cached, err := store.GetRoutingProgram(rule) + cached, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, cached) } @@ -975,7 +975,7 @@ func TestCompileAndCacheProgram_EmptyExpression_Routing(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) } @@ -997,7 +997,7 @@ func TestCompileAndCacheProgram_InvalidExpression_Routing(t *testing.T) { Enabled: true, } - _, err = store.GetRoutingProgram(rule) + _, err = store.GetRoutingProgram(context.Background(), rule) assert.Error(t, err) } @@ -1008,7 +1008,7 @@ func TestCompileAndCacheProgram_NilRule(t *testing.T) { store, err := NewLocalGovernanceStore(ctx, logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - _, err = store.GetRoutingProgram(nil) + _, err = store.GetRoutingProgram(context.Background(), nil) assert.Error(t, err) assert.Contains(t, err.Error(), "cannot be nil") } @@ -1030,7 +1030,7 @@ func TestCompileAndCacheProgram_ListExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) 
assert.NotNil(t, program) } @@ -1052,7 +1052,7 @@ func TestCompileAndCacheProgram_RegexExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) } @@ -1074,7 +1074,7 @@ func TestCompileAndCacheProgram_HeaderExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) } @@ -1096,7 +1096,7 @@ func TestCompileAndCacheProgram_RateLimitExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) } @@ -1118,7 +1118,7 @@ func TestCompileAndCacheProgram_BudgetExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) } @@ -1140,7 +1140,7 @@ func TestCompileAndCacheProgram_ComplexExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) } @@ -1195,7 +1195,7 @@ func TestEvaluateCELExpression_TrueResult(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) variables := map[string]interface{}{ @@ -1226,7 +1226,7 @@ func TestEvaluateCELExpression_FalseResult(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) variables := map[string]interface{}{ @@ -1257,7 +1257,7 @@ func 
TestEvaluateCELExpression_ListMembership(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) // Test: model in list @@ -1295,7 +1295,7 @@ func TestEvaluateCELExpression_HeaderAccess(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) variables := map[string]interface{}{ diff --git a/plugins/governance/store.go b/plugins/governance/store.go index 751a9ce0ec..1cd0031d1c 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -18,6 +18,9 @@ import ( "gorm.io/gorm" ) +type EntityWiseBudgets map[string][]*configstoreTables.TableBudget +type EntityWiseRateLimits map[string][]*configstoreTables.TableRateLimit + // LocalGovernanceStore provides in-memory cache for governance data with fast, non-blocking access type LocalGovernanceStore struct { // Core data maps using sync.Map for lock-free reads @@ -29,7 +32,6 @@ type LocalGovernanceStore struct { modelConfigs sync.Map // string -> *ModelConfig (key: "modelName" or "modelName:provider" -> ModelConfig) providers sync.Map // string -> *Provider (Provider name -> Provider with preloaded relationships) routingRules sync.Map // string -> []*TableRoutingRule (key: "scope:scopeID" -> rules, scopeID="" for global) - users sync.Map // string -> *UserGovernance (User ID -> UserGovernance, enterprise-only) // Last DB usages for budgets and rate limits LastDBUsagesBudgetsMu sync.RWMutex // Last DB usages for budgets @@ -65,13 +67,16 @@ type GovernanceData struct { Providers []*configstoreTables.TableProvider `json:"providers"` } +// BusinessUnitGovernance holds in-memory budget and rate limit data for a business unit +type BusinessUnitGovernance struct { + BudgetID *string + RateLimitID *string +} + // UserGovernance holds governance data for a user (enterprise-only) 
type UserGovernance struct { - UserID string `json:"user_id"` - BudgetID *string `json:"budget_id,omitempty"` - RateLimitID *string `json:"rate_limit_id,omitempty"` - Budget *configstoreTables.TableBudget `json:"budget,omitempty"` - RateLimit *configstoreTables.TableRateLimit `json:"rate_limit,omitempty"` + BudgetID *string `json:"budget_id,omitempty"` + RateLimitID *string `json:"rate_limit_id,omitempty"` } // BudgetAndRateLimitStatus represents the current budget and rate limit usage state @@ -92,17 +97,25 @@ type BudgetAndRateLimitStatus struct { // - This contract ensures consistent behavior across implementations (e.g., in-memory, // DB-backed) and prevents retry loops on policy violations. type GovernanceStore interface { - GetGovernanceData() *GovernanceData - GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) + GetGovernanceData(ctx context.Context) *GovernanceData + GetVirtualKey(ctx context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool) + // Budget crud + LoadBudget(ctx context.Context, budgetID string) *configstoreTables.TableBudget + StoreBudget(ctx context.Context, budgetID string, budget *configstoreTables.TableBudget) + DeleteBudget(ctx context.Context, budgetID string) + // Rate limit crud + LoadRateLimit(ctx context.Context, rateLimitID string) *configstoreTables.TableRateLimit + StoreRateLimit(ctx context.Context, rateLimitID string, rateLimit *configstoreTables.TableRateLimit) + DeleteRateLimit(ctx context.Context, rateLimitID string) // Provider-level governance checks - CheckProviderBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) error - CheckProviderRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) + CheckProviderBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) (Decision, error) + CheckProviderRateLimit(ctx context.Context, request 
*EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) // Model-level governance checks - CheckModelBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) error - CheckModelRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) + CheckModelBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) (Decision, error) + CheckModelRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) // VK-level governance checks - CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error - CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) + CheckVirtualKeyBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) (Decision, error) + CheckVirtualKeyRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) // In-memory usage updates (for VK-level) UpdateVirtualKeyBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error UpdateVirtualKeyRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error @@ -119,41 +132,48 @@ type GovernanceStore interface { DumpRateLimits(ctx context.Context, tokenBaselines map[string]int64, requestBaselines map[string]int64) error DumpBudgets(ctx context.Context, 
baselines map[string]float64) error // In-memory CRUD operations - CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) - UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) - DeleteVirtualKeyInMemory(vkID string) - CreateTeamInMemory(team *configstoreTables.TableTeam) - UpdateTeamInMemory(team *configstoreTables.TableTeam, budgetBaselines map[string]float64) - DeleteTeamInMemory(teamID string) - CreateCustomerInMemory(customer *configstoreTables.TableCustomer) - UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) - DeleteCustomerInMemory(customerID string) + CreateVirtualKeyInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey) + UpdateVirtualKeyInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) + DeleteVirtualKeyInMemory(ctx context.Context, vkID string) + CreateTeamInMemory(ctx context.Context, team *configstoreTables.TableTeam) + UpdateTeamInMemory(ctx context.Context, team *configstoreTables.TableTeam, budgetBaselines map[string]float64) + DeleteTeamInMemory(ctx context.Context, teamID string) + // Customer in-memory CRUD operations + CreateCustomerInMemory(ctx context.Context, customer *configstoreTables.TableCustomer) + UpdateCustomerInMemory(ctx context.Context, customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) + DeleteCustomerInMemory(ctx context.Context, customerID string) + // Team-level governance checks + CheckTeamBudget(ctx context.Context, teamID string, request *EvaluationRequest, baselines map[string]float64) (Decision, error) + CheckTeamRateLimit(ctx context.Context, teamID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision,
error) + // Customer-level governance checks + CheckCustomerBudget(ctx context.Context, customerID string, request *EvaluationRequest, baselines map[string]float64) (Decision, error) + CheckCustomerRateLimit(ctx context.Context, customerID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) // User governance in-memory operations (enterprise-only, but interface defined here for compatibility) - GetUserGovernance(userID string) (*UserGovernance, bool) - CreateUserGovernanceInMemory(userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) - UpdateUserGovernanceInMemory(userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) - DeleteUserGovernanceInMemory(userID string) + GetUserGovernance(ctx context.Context, userID string) (*UserGovernance, bool) + CreateUserGovernanceInMemory(ctx context.Context, userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) + UpdateUserGovernanceInMemory(ctx context.Context, userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) + DeleteUserGovernanceInMemory(ctx context.Context, userID string) // User-level governance checks (enterprise-only) - CheckUserBudget(ctx context.Context, userID string, request *EvaluationRequest, baselines map[string]float64) error - CheckUserRateLimit(ctx context.Context, userID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) + CheckUserBudget(ctx context.Context, userID string, request *EvaluationRequest, baselines map[string]float64) (Decision, error) + CheckUserRateLimit(ctx context.Context, userID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) UpdateUserBudgetUsageInMemory(ctx context.Context, userID string, cost 
float64) error UpdateUserRateLimitUsageInMemory(ctx context.Context, userID string, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error // Model config in-memory operations - UpdateModelConfigInMemory(mc *configstoreTables.TableModelConfig) *configstoreTables.TableModelConfig - DeleteModelConfigInMemory(mcID string) + UpdateModelConfigInMemory(ctx context.Context, mc *configstoreTables.TableModelConfig) *configstoreTables.TableModelConfig + DeleteModelConfigInMemory(ctx context.Context, mcID string) // Provider in-memory operations - UpdateProviderInMemory(provider *configstoreTables.TableProvider) *configstoreTables.TableProvider - DeleteProviderInMemory(providerName string) + UpdateProviderInMemory(ctx context.Context, provider *configstoreTables.TableProvider) *configstoreTables.TableProvider + DeleteProviderInMemory(ctx context.Context, providerName string) // Routing Rules CEL caching - GetRoutingProgram(rule *configstoreTables.TableRoutingRule) (cel.Program, error) + GetRoutingProgram(ctx context.Context, rule *configstoreTables.TableRoutingRule) (cel.Program, error) // Budget and rate limit status queries for routing with baseline support GetBudgetAndRateLimitStatus(ctx context.Context, model string, provider schemas.ModelProvider, vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, tokenBaselines map[string]int64, requestBaselines map[string]int64) *BudgetAndRateLimitStatus // Routing Rules CRUD HasRoutingRules(ctx context.Context) bool - GetAllRoutingRules() []*configstoreTables.TableRoutingRule - GetScopedRoutingRules(scope string, scopeID string) []*configstoreTables.TableRoutingRule - UpdateRoutingRuleInMemory(rule *configstoreTables.TableRoutingRule) error - DeleteRoutingRuleInMemory(id string) error + GetAllRoutingRules(ctx context.Context) []*configstoreTables.TableRoutingRule + GetScopedRoutingRules(ctx context.Context, scope string, scopeID string) []*configstoreTables.TableRoutingRule + 
UpdateRoutingRuleInMemory(ctx context.Context, rule *configstoreTables.TableRoutingRule) error + DeleteRoutingRuleInMemory(ctx context.Context, id string) error } // NewLocalGovernanceStore creates a new in-memory governance store @@ -191,7 +211,48 @@ func NewLocalGovernanceStore(ctx context.Context, logger schemas.Logger, configS return store, nil } -func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { +// LoadBudget loads a budget by its ID from the local store. +func (gs *LocalGovernanceStore) LoadBudget(ctx context.Context, budgetID string) *configstoreTables.TableBudget { + if budget, ok := gs.budgets.Load(budgetID); ok { + if b, ok := budget.(*configstoreTables.TableBudget); ok { + return b + } + } + return nil +} + +// StoreBudget stores a budget in the local store. +func (gs *LocalGovernanceStore) StoreBudget(ctx context.Context, budgetID string, budget *configstoreTables.TableBudget) { + gs.budgets.Store(budgetID, budget) +} + +// DeleteBudget deletes a budget from the local store. +func (gs *LocalGovernanceStore) DeleteBudget(ctx context.Context, budgetID string) { + gs.budgets.Delete(budgetID) +} + +// LoadRateLimit loads a rate limit by its ID from the local store. +func (gs *LocalGovernanceStore) LoadRateLimit(ctx context.Context, rateLimitID string) *configstoreTables.TableRateLimit { + if rateLimit, ok := gs.rateLimits.Load(rateLimitID); ok { + if rl, ok := rateLimit.(*configstoreTables.TableRateLimit); ok { + return rl + } + } + return nil +} + +// StoreRateLimit stores a rate limit in the local store. +func (gs *LocalGovernanceStore) StoreRateLimit(ctx context.Context, rateLimitID string, rateLimit *configstoreTables.TableRateLimit) { + gs.rateLimits.Store(rateLimitID, rateLimit) +} + +// DeleteRateLimit deletes a rate limit from the local store. 
+func (gs *LocalGovernanceStore) DeleteRateLimit(ctx context.Context, rateLimitID string) { + gs.rateLimits.Delete(rateLimitID) +} + +// GetGovernanceData returns a snapshot of the current governance data. +func (gs *LocalGovernanceStore) GetGovernanceData(ctx context.Context) *GovernanceData { refreshVKAssociations := func(vk *configstoreTables.TableVirtualKey) { if vk == nil { return @@ -264,7 +325,6 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { } } } - virtualKeys := make(map[string]*configstoreTables.TableVirtualKey) gs.virtualKeys.Range(func(key, value interface{}) bool { vk, ok := value.(*configstoreTables.TableVirtualKey) @@ -316,7 +376,7 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { customers[key.(string)] = &clone return true // continue iteration }) - + // virtualKeys level data for _, vk := range virtualKeys { if vk == nil { continue @@ -337,7 +397,7 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { } } } - + // Team level data for _, team := range teams { if team == nil { continue @@ -352,7 +412,7 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { } } } - + // Customer level data for _, customer := range customers { if customer == nil { continue @@ -370,7 +430,6 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { return customer.VirtualKeys[i].CreatedAt.Before(customer.VirtualKeys[j].CreatedAt) }) } - budgets := make(map[string]*configstoreTables.TableBudget) gs.budgets.Range(func(key, value interface{}) bool { budget, ok := value.(*configstoreTables.TableBudget) @@ -461,37 +520,10 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { sort.Slice(providersList, func(i, j int) bool { return providersList[i].CreatedAt.Before(providersList[j].CreatedAt) }) - // Collect user governance data (enterprise-only) - users := make(map[string]*UserGovernance) - gs.users.Range(func(key, value interface{}) bool { - ug, ok := 
value.(*UserGovernance) - if !ok || ug == nil { - return true // continue - } - // Cross-reference live budget/rate limit from standalone maps - clone := *ug - if clone.BudgetID != nil { - if liveBudget, exists := gs.budgets.Load(*clone.BudgetID); exists && liveBudget != nil { - if b, ok := liveBudget.(*configstoreTables.TableBudget); ok { - clone.Budget = b - } - } - } - if clone.RateLimitID != nil { - if liveRL, exists := gs.rateLimits.Load(*clone.RateLimitID); exists && liveRL != nil { - if rl, ok := liveRL.(*configstoreTables.TableRateLimit); ok { - clone.RateLimit = rl - } - } - } - users[key.(string)] = &clone - return true // continue iteration - }) return &GovernanceData{ VirtualKeys: virtualKeys, Teams: teams, Customers: customers, - Users: users, Budgets: budgets, RateLimits: rateLimits, RoutingRules: routingRules, @@ -501,12 +533,11 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { } // GetVirtualKey retrieves a virtual key by its value (lock-free) with all relationships preloaded -func (gs *LocalGovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) { +func (gs *LocalGovernanceStore) GetVirtualKey(ctx context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool) { value, exists := gs.virtualKeys.Load(vkValue) if !exists || value == nil { return nil, false } - vk, ok := value.(*configstoreTables.TableVirtualKey) if !ok || vk == nil { return nil, false @@ -514,246 +545,197 @@ func (gs *LocalGovernanceStore) GetVirtualKey(vkValue string) (*configstoreTable return vk, true } -// CheckBudget performs budget checking using in-memory store data (lock-free for high performance) -func (gs *LocalGovernanceStore) CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error { - if vk == nil { - return fmt.Errorf("virtual key cannot be nil") +// CheckRateLimit checks rate limits for tokens and requests across categories 
+func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, entityWiseRateLimits EntityWiseRateLimits, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { + for entity, rateLimits := range entityWiseRateLimits { + for _, rateLimit := range rateLimits { + var violations []string + // Check if rate limit needs reset (in-memory check) + // Track which limits are expired so we can skip only those specific checks + tokenLimitExpired := false + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if time.Since(rateLimit.TokenLastReset) >= duration { + // Token rate limit expired but hasn't been reset yet - skip token check only + tokenLimitExpired = true + } + } + } + requestLimitExpired := false + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if time.Since(rateLimit.RequestLastReset) >= duration { + // Request rate limit expired but hasn't been reset yet - skip request check only + requestLimitExpired = true + } + } + } + + tokensBaseline, exists := tokensBaselines[rateLimit.ID] + if !exists { + tokensBaseline = 0 + } + requestsBaseline, exists := requestsBaselines[rateLimit.ID] + if !exists { + requestsBaseline = 0 + } + + // Token limits - check if total usage (local + remote baseline) exceeds limit + // Skip this check if token limit has expired + if !tokenLimitExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { + duration := "unknown" + if rateLimit.TokenResetDuration != nil { + duration = *rateLimit.TokenResetDuration + } + violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", + rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) + } + + // Request limits - check if total usage (local + remote 
baseline) exceeds limit + // Skip this check if request limit has expired + if !requestLimitExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { + duration := "unknown" + if rateLimit.RequestResetDuration != nil { + duration = *rateLimit.RequestResetDuration + } + violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", + rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) + } + + if len(violations) > 0 { + // Determine specific violation type + decision := DecisionRateLimited // Default to general rate limited decision + if len(violations) == 1 { + if strings.Contains(violations[0], "token") { + decision = DecisionTokenLimited // More specific violation type + } else if strings.Contains(violations[0], "request") { + decision = DecisionRequestLimited // More specific violation type + } + } + return decision, fmt.Errorf("rate limit violated for %s: %s", entity, violations) + } + } } + return DecisionAllow, nil +} +// Generic check budget method +// The idea is to keep this as a common method for checking all budgets. 
The entire business logic resides in here +func (gs *LocalGovernanceStore) CheckBudget(ctx context.Context, entityWiseBudgets EntityWiseBudgets, baselines map[string]float64) (Decision, error) { + // Check each budget in hierarchy order using in-memory data + for entity, budgets := range entityWiseBudgets { + for _, budget := range budgets { // Check if budget needs reset (in-memory check) + if budget.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { + if time.Since(budget.LastReset) >= duration { + // Budget expired but hasn't been reset yet - treat as reset + // Note: actual reset will happen in post-hook via AtomicBudgetUpdate + gs.logger.Debug("LocalStore CheckBudget: Budget %s (%s) expired, skipping check", budget.ID, entity) + continue // Skip budget check for expired budgets + } + } + } + baseline, exists := baselines[budget.ID] + if !exists { + baseline = 0 + } + gs.logger.Debug("LocalStore CheckBudget: Checking %s budget %s: local=%.4f, remote=%.4f, total=%.4f, limit=%.4f", + entity, budget.ID, budget.CurrentUsage, baseline, budget.CurrentUsage+baseline, budget.MaxLimit) + // Check if current usage (local + remote baseline) exceeds budget limit + if budget.CurrentUsage+baseline >= budget.MaxLimit { + gs.logger.Debug("LocalStore CheckBudget: Budget %s EXCEEDED", budget.ID) + return DecisionBudgetExceeded, fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", + entity, budget.CurrentUsage+baseline, budget.MaxLimit) + } + } + } + return DecisionAllow, nil +} + +// CheckVirtualKeyBudget performs virtual key level budget checking using in-memory store data (lock-free for high performance) +func (gs *LocalGovernanceStore) CheckVirtualKeyBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) (Decision, error) { + if vk == nil { + return DecisionVirtualKeyNotFound, fmt.Errorf("virtual key cannot be nil") + } // This is to prevent nil 
pointer dereference if baselines == nil { baselines = map[string]float64{} } - // Extract provider from request var provider schemas.ModelProvider if request != nil { provider = request.Provider } - // Use helper to collect budgets and their names (lock-free) - budgetsToCheck, budgetNames := gs.collectBudgetsFromHierarchy(vk, provider) - + budgetsWithCategories := gs.collectBudgetsFromHierarchy(ctx, vk, provider) gs.logger.Debug("LocalStore CheckBudget: Received %d baselines from remote nodes", len(baselines)) for budgetID, baseline := range baselines { gs.logger.Debug(" - Baseline for budget %s: %.4f", budgetID, baseline) } - - // Check each budget in hierarchy order using in-memory data - for i, budget := range budgetsToCheck { - // Check if budget needs reset (in-memory check) - if budget.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { - if time.Since(budget.LastReset) >= duration { - // Budget expired but hasn't been reset yet - treat as reset - // Note: actual reset will happen in post-hook via AtomicBudgetUpdate - gs.logger.Debug("LocalStore CheckBudget: Budget %s (%s) expired, skipping check", budget.ID, budgetNames[i]) - continue // Skip budget check for expired budgets - } - } - } - - baseline, exists := baselines[budget.ID] - if !exists { - baseline = 0 - } - - gs.logger.Debug("LocalStore CheckBudget: Checking %s budget %s: local=%.4f, remote=%.4f, total=%.4f, limit=%.4f", - budgetNames[i], budget.ID, budget.CurrentUsage, baseline, budget.CurrentUsage+baseline, budget.MaxLimit) - - // Check if current usage (local + remote baseline) exceeds budget limit - if budget.CurrentUsage+baseline >= budget.MaxLimit { - gs.logger.Debug("LocalStore CheckBudget: Budget %s EXCEEDED", budget.ID) - return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", - budgetNames[i], budget.CurrentUsage+baseline, budget.MaxLimit) - } - } - - gs.logger.Debug("LocalStore CheckBudget: All budgets passed") - - return 
nil + return gs.CheckBudget(ctx, budgetsWithCategories, baselines) } // CheckProviderBudget performs budget checking for provider-level configs (lock-free for high performance) -func (gs *LocalGovernanceStore) CheckProviderBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) error { +func (gs *LocalGovernanceStore) CheckProviderBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) (Decision, error) { // This is to prevent nil pointer dereference if baselines == nil { baselines = map[string]float64{} } - // Extract provider from request var provider schemas.ModelProvider if request != nil { provider = request.Provider } - // Get provider config providerKey := string(provider) value, exists := gs.providers.Load(providerKey) if !exists || value == nil { // No provider config found, allow request - return nil + return DecisionAllow, nil } - providerTable, ok := value.(*configstoreTables.TableProvider) if !ok || providerTable == nil || providerTable.BudgetID == nil { // No budget configured for provider, allow request - return nil + return DecisionAllow, nil } - // Read from budgets map to get the latest updated budget (same source as UpdateProviderBudgetUsage) - budgetValue, exists := gs.budgets.Load(*providerTable.BudgetID) - if !exists || budgetValue == nil { - // Budget not found in cache, allow request - return nil - } - - budget, ok := budgetValue.(*configstoreTables.TableBudget) - if !ok || budget == nil { - // Invalid budget type, allow request - return nil - } - - // Check if budget needs reset (in-memory check) - if budget.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { - if time.Since(budget.LastReset) >= duration { - // Budget expired but hasn't been reset yet - treat as reset - return nil // Skip budget check for expired budgets - } - } - } - - baseline, exists := baselines[budget.ID] - if !exists { - baseline = 0 - } - - // Check 
if current usage (local + remote baseline) exceeds budget limit - if budget.CurrentUsage+baseline >= budget.MaxLimit { - return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", - providerKey, budget.CurrentUsage+baseline, budget.MaxLimit) + budget := gs.LoadBudget(ctx, *providerTable.BudgetID) + if budget == nil { + return DecisionAllow, nil } - - return nil + return gs.CheckBudget(ctx, map[string][]*configstoreTables.TableBudget{providerKey: {budget}}, baselines) } // CheckProviderRateLimit checks provider-level rate limits and returns evaluation result if violated -func (gs *LocalGovernanceStore) CheckProviderRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) { - var violations []string - - // This is to prevent nil pointer dereference - if tokensBaselines == nil { - tokensBaselines = map[string]int64{} - } - if requestsBaselines == nil { - requestsBaselines = map[string]int64{} - } - +func (gs *LocalGovernanceStore) CheckProviderRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { // Extract provider from request var provider schemas.ModelProvider if request != nil { provider = request.Provider } - // Get provider config providerKey := string(provider) value, exists := gs.providers.Load(providerKey) if !exists || value == nil { // No provider config found, allow request - return nil, DecisionAllow + return DecisionAllow, nil } - providerTable, ok := value.(*configstoreTables.TableProvider) if !ok || providerTable == nil || providerTable.RateLimitID == nil { // No rate limit configured for provider, allow request - return nil, DecisionAllow + return DecisionAllow, nil } - // Read from rateLimits map to get the latest updated rate limit (same source as UpdateProviderRateLimitUsage) - rateLimitValue, exists := gs.rateLimits.Load(*providerTable.RateLimitID) - if !exists 
|| rateLimitValue == nil { - // Rate limit not found in cache, allow request - return nil, DecisionAllow - } - - rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit) - if !ok || rateLimit == nil { - // Invalid rate limit type, allow request - return nil, DecisionAllow - } - - // Check if rate limit needs reset (in-memory check) - // Track which limits are expired so we can skip only those specific checks - tokenLimitExpired := false - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if time.Since(rateLimit.TokenLastReset) >= duration { - // Token rate limit expired but hasn't been reset yet - skip token check only - tokenLimitExpired = true - } - } - } - requestLimitExpired := false - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if time.Since(rateLimit.RequestLastReset) >= duration { - // Request rate limit expired but hasn't been reset yet - skip request check only - requestLimitExpired = true - } - } + rateLimit := gs.LoadRateLimit(ctx, *providerTable.RateLimitID) + if rateLimit == nil { + return DecisionAllow, nil } - - tokensBaseline, exists := tokensBaselines[rateLimit.ID] - if !exists { - tokensBaseline = 0 - } - requestsBaseline, exists := requestsBaselines[rateLimit.ID] - if !exists { - requestsBaseline = 0 - } - - // Token limits - check if total usage (local + remote baseline) exceeds limit - // Skip this check if token limit has expired - if !tokenLimitExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { - duration := "unknown" - if rateLimit.TokenResetDuration != nil { - duration = *rateLimit.TokenResetDuration - } - violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", - rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, 
duration)) - } - - // Request limits - check if total usage (local + remote baseline) exceeds limit - // Skip this check if request limit has expired - if !requestLimitExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { - duration := "unknown" - if rateLimit.RequestResetDuration != nil { - duration = *rateLimit.RequestResetDuration - } - violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", - rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) - } - - if len(violations) > 0 { - // Determine specific violation type - decision := DecisionRateLimited // Default to general rate limited decision - if len(violations) == 1 { - if strings.Contains(violations[0], "token") { - decision = DecisionTokenLimited // More specific violation type - } else if strings.Contains(violations[0], "request") { - decision = DecisionRequestLimited // More specific violation type - } - } - return fmt.Errorf("rate limit violated for %s: %s", providerKey, violations), decision - } - - return nil, DecisionAllow // No rate limit violations + return gs.CheckRateLimit(ctx, EntityWiseRateLimits{providerKey: []*configstoreTables.TableRateLimit{rateLimit}}, tokensBaselines, requestsBaselines) } // findModelOnlyConfig looks up a model-only config (no provider) with cross-provider model name normalization. // Returns the matching config and the display name for error messages. 
-func (gs *LocalGovernanceStore) findModelOnlyConfig(model string) (*configstoreTables.TableModelConfig, string) { +func (gs *LocalGovernanceStore) findModelOnlyConfig(ctx context.Context, model string) (*configstoreTables.TableModelConfig, string) { // If modelMatcher is available, try normalized base model name first (cross-provider matching) if gs.modelCatalog != nil { baseName := gs.modelCatalog.GetBaseModelName(model) @@ -765,24 +747,21 @@ func (gs *LocalGovernanceStore) findModelOnlyConfig(model string) (*configstoreT } } } - // Always try direct lookup by original model name as fallback if value, exists := gs.modelConfigs.Load(model); exists && value != nil { if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil { return mc, model } } - return nil, "" } // CheckModelBudget performs budget checking for model-level configs (lock-free for high performance) -func (gs *LocalGovernanceStore) CheckModelBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) error { +func (gs *LocalGovernanceStore) CheckModelBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) (Decision, error) { // This is to prevent nil pointer dereference if baselines == nil { baselines = map[string]float64{} } - // Extract model and provider from request var model string var provider *schemas.ModelProvider @@ -792,348 +771,200 @@ func (gs *LocalGovernanceStore) CheckModelBudget(ctx context.Context, request *E provider = &request.Provider } } - // Collect model configs to check: model+provider (if exists) AND model-only (if exists) - var modelConfigsToCheck []*configstoreTables.TableModelConfig - var budgetNames []string - + entityWiseBudgets := EntityWiseBudgets{} // Check model+provider config first (more specific) - if provider is provided if provider != nil { key := fmt.Sprintf("%s:%s", model, string(*provider)) if value, exists := gs.modelConfigs.Load(key); exists && value != nil { if mc, ok := 
value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.Budget != nil { - modelConfigsToCheck = append(modelConfigsToCheck, mc) - budgetNames = append(budgetNames, fmt.Sprintf("Model:%s:Provider:%s", model, string(*provider))) + budget := gs.LoadBudget(ctx, *mc.BudgetID) + if budget != nil { + key := fmt.Sprintf("Model:%s:Provider:%s", mc.ModelName, *provider) + entityWiseBudgets[key] = []*configstoreTables.TableBudget{budget} + } } } } - // Always check model-only config (if exists) - regardless of whether model+provider config exists // Uses findModelOnlyConfig for cross-provider model name normalization - if mc, configKey := gs.findModelOnlyConfig(model); mc != nil && mc.Budget != nil { - modelConfigsToCheck = append(modelConfigsToCheck, mc) - budgetNames = append(budgetNames, fmt.Sprintf("Model:%s", configKey)) - } - - // Check each model budget - for i, mc := range modelConfigsToCheck { - if mc.BudgetID == nil { - continue - } - - // Read from budgets map to get the latest updated budget (same source as UpdateModelBudgetUsage) - budgetValue, exists := gs.budgets.Load(*mc.BudgetID) - if !exists || budgetValue == nil { - // Budget not found in cache, skip check - continue - } - - budget, ok := budgetValue.(*configstoreTables.TableBudget) - if !ok || budget == nil { - // Invalid budget type, skip check - continue - } - - // Check if budget needs reset (in-memory check) - if budget.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { - if time.Since(budget.LastReset) >= duration { - // Budget expired but hasn't been reset yet - treat as reset - continue // Skip budget check for expired budgets - } - } - } - - baseline, exists := baselines[budget.ID] - if !exists { - baseline = 0 - } - - // Check if current usage (local + remote baseline) exceeds budget limit - if budget.CurrentUsage+baseline >= budget.MaxLimit { - return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", - budgetNames[i], 
budget.CurrentUsage+baseline, budget.MaxLimit) + if mc, _ := gs.findModelOnlyConfig(ctx, model); mc != nil && mc.Budget != nil { + budget := gs.LoadBudget(ctx, *mc.BudgetID) + if budget != nil { + key := fmt.Sprintf("Model:%s", mc.ModelName) + entityWiseBudgets[key] = []*configstoreTables.TableBudget{budget} } } - - return nil + return gs.CheckBudget(ctx, entityWiseBudgets, baselines) } -// CheckUserBudget checks if user's budget allows the request (enterprise-only) -func (gs *LocalGovernanceStore) CheckUserBudget(ctx context.Context, userID string, request *EvaluationRequest, baselines map[string]float64) error { - if userID == "" { - return nil // No user, skip check +// CheckTeamBudget checks team-level budget and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckTeamBudget(ctx context.Context, teamID string, request *EvaluationRequest, baselines map[string]float64) (Decision, error) { + if teamID == "" { + return DecisionAllow, nil } - if baselines == nil { baselines = map[string]float64{} } - - ug, exists := gs.GetUserGovernance(userID) - if !exists || ug == nil || ug.BudgetID == nil { - return nil // No budget configured for user + teamValue, exists := gs.teams.Load(teamID) + if !exists || teamValue == nil { + return DecisionAllow, nil } - - budgetValue, exists := gs.budgets.Load(*ug.BudgetID) - if !exists || budgetValue == nil { - return nil + team, ok := teamValue.(*configstoreTables.TableTeam) + if !ok || team.BudgetID == nil { + return DecisionAllow, nil } - - budget, ok := budgetValue.(*configstoreTables.TableBudget) - if !ok || budget == nil { - return nil - } - - // Check if budget needs reset - if budget.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { - if time.Since(budget.LastReset) >= duration { - return nil // Budget expired, skip check - } - } + teamBudget := gs.LoadBudget(ctx, *team.BudgetID) + if teamBudget == nil { + return DecisionAllow, nil } - - 
baseline := baselines[budget.ID] - if budget.CurrentUsage+baseline >= budget.MaxLimit { - return fmt.Errorf("user budget exceeded: %.4f >= %.4f dollars", budget.CurrentUsage+baseline, budget.MaxLimit) - } - - return nil + key := fmt.Sprintf("Team:%s", teamID) + entityWiseBudgets := EntityWiseBudgets{key: {teamBudget}} + return gs.CheckBudget(ctx, entityWiseBudgets, baselines) } -// CheckModelRateLimit checks model-level rate limits and returns evaluation result if violated -func (gs *LocalGovernanceStore) CheckModelRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) { - var violations []string - - // This is to prevent nil pointer dereference +// CheckTeamRateLimit checks team-level rate limit and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckTeamRateLimit(ctx context.Context, teamID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { if tokensBaselines == nil { tokensBaselines = map[string]int64{} } if requestsBaselines == nil { requestsBaselines = map[string]int64{} } - - // Extract model and provider from request - var model string - var provider *schemas.ModelProvider - if request != nil { - model = request.Model - if request.Provider != "" { - provider = &request.Provider - } + teamValue, exists := gs.teams.Load(teamID) + if !exists || teamValue == nil { + return DecisionAllow, nil } - - // Collect model configs to check: model+provider (if exists) AND model-only (if exists) - var modelConfigsToCheck []*configstoreTables.TableModelConfig - var rateLimitNames []string - - // Check model+provider config first (more specific) - if provider is provided - if provider != nil { - key := fmt.Sprintf("%s:%s", model, string(*provider)) - if value, exists := gs.modelConfigs.Load(key); exists && value != nil { - if mc, ok := value.(*configstoreTables.TableModelConfig); ok && 
mc != nil && mc.RateLimitID != nil { - modelConfigsToCheck = append(modelConfigsToCheck, mc) - rateLimitNames = append(rateLimitNames, fmt.Sprintf("Model:%s:Provider:%s", model, string(*provider))) - } - } + team, ok := teamValue.(*configstoreTables.TableTeam) + if !ok || team.RateLimitID == nil { + return DecisionAllow, nil } - - // Always check model-only config (if exists) - regardless of whether model+provider config exists - // Uses findModelOnlyConfig for cross-provider model name normalization - if mc, configKey := gs.findModelOnlyConfig(model); mc != nil && mc.RateLimitID != nil { - modelConfigsToCheck = append(modelConfigsToCheck, mc) - rateLimitNames = append(rateLimitNames, fmt.Sprintf("Model:%s", configKey)) + teamRateLimit := gs.LoadRateLimit(ctx, *team.RateLimitID) + if teamRateLimit == nil { + return DecisionAllow, nil } + key := fmt.Sprintf("Team:%s", teamID) + entityWiseRateLimits := EntityWiseRateLimits{key: {teamRateLimit}} + return gs.CheckRateLimit(ctx, entityWiseRateLimits, tokensBaselines, requestsBaselines) +} - // Check each model rate limit - for i, mc := range modelConfigsToCheck { - if mc.RateLimitID == nil { - continue - } - - // Read from rateLimits map to get the latest updated rate limit (same source as UpdateModelRateLimitUsage) - rateLimitValue, exists := gs.rateLimits.Load(*mc.RateLimitID) - if !exists || rateLimitValue == nil { - // Rate limit not found in cache, skip check - continue - } - - rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit) - if !ok || rateLimit == nil { - // Invalid rate limit type, skip check - continue - } - - // Check if rate limit needs reset (in-memory check) - // Track which limits are expired so we can skip only those specific checks - tokenLimitExpired := false - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if time.Since(rateLimit.TokenLastReset) >= duration { - // Token rate limit expired 
but hasn't been reset yet - skip token check only - tokenLimitExpired = true - } - } - } - requestLimitExpired := false - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if time.Since(rateLimit.RequestLastReset) >= duration { - // Request rate limit expired but hasn't been reset yet - skip request check only - requestLimitExpired = true - } - } - } - - tokensBaseline, exists := tokensBaselines[rateLimit.ID] - if !exists { - tokensBaseline = 0 - } - requestsBaseline, exists := requestsBaselines[rateLimit.ID] - if !exists { - requestsBaseline = 0 - } - - // Token limits - check if total usage (local + remote baseline) exceeds limit - // Skip this check if token limit has expired - if !tokenLimitExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { - duration := "unknown" - if rateLimit.TokenResetDuration != nil { - duration = *rateLimit.TokenResetDuration - } - violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", - rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) - } - - // Request limits - check if total usage (local + remote baseline) exceeds limit - // Skip this check if request limit has expired - if !requestLimitExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { - duration := "unknown" - if rateLimit.RequestResetDuration != nil { - duration = *rateLimit.RequestResetDuration - } - violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", - rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) - } - - if len(violations) > 0 { - // Determine specific violation type - decision := DecisionRateLimited // Default to general rate limited decision - if len(violations) == 1 { - if 
strings.Contains(violations[0], "token") { - decision = DecisionTokenLimited // More specific violation type - } else if strings.Contains(violations[0], "request") { - decision = DecisionRequestLimited // More specific violation type - } - } - return fmt.Errorf("rate limit violated for %s: %s", rateLimitNames[i], violations), decision - } +// CheckCustomerBudget checks customer-level budget and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckCustomerBudget(ctx context.Context, customerID string, request *EvaluationRequest, baselines map[string]float64) (Decision, error) { + if customerID == "" { + return DecisionAllow, nil } - - return nil, DecisionAllow // No rate limit violations + if baselines == nil { + baselines = map[string]float64{} + } + customerValue, exists := gs.customers.Load(customerID) + if !exists || customerValue == nil { + return DecisionAllow, nil + } + customer, ok := customerValue.(*configstoreTables.TableCustomer) + if !ok || customer.BudgetID == nil { + return DecisionAllow, nil + } + customerBudget := gs.LoadBudget(ctx, *customer.BudgetID) + if customerBudget == nil { + return DecisionAllow, nil + } + key := fmt.Sprintf("Customer:%s", customerID) + entityWiseBudgets := EntityWiseBudgets{key: {customerBudget}} + return gs.CheckBudget(ctx, entityWiseBudgets, baselines) } -// CheckUserRateLimit checks if user's rate limit allows the request (enterprise-only) -func (gs *LocalGovernanceStore) CheckUserRateLimit(ctx context.Context, userID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) { - if userID == "" { - return nil, DecisionAllow // No user, skip check +// CheckCustomerRateLimit checks customer-level rate limit and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckCustomerRateLimit(ctx context.Context, customerID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines 
map[string]int64) (Decision, error) { + if customerID == "" { + return DecisionAllow, nil } - if tokensBaselines == nil { tokensBaselines = map[string]int64{} } if requestsBaselines == nil { requestsBaselines = map[string]int64{} } - - ug, exists := gs.GetUserGovernance(userID) - if !exists || ug == nil || ug.RateLimitID == nil { - return nil, DecisionAllow // No rate limit configured for user + customerValue, exists := gs.customers.Load(customerID) + if !exists || customerValue == nil { + return DecisionAllow, nil } - - rateLimitValue, exists := gs.rateLimits.Load(*ug.RateLimitID) - if !exists || rateLimitValue == nil { - return nil, DecisionAllow + customer, ok := customerValue.(*configstoreTables.TableCustomer) + if !ok || customer.RateLimitID == nil { + return DecisionAllow, nil } - - rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit) - if !ok || rateLimit == nil { - return nil, DecisionAllow + customerRateLimit := gs.LoadRateLimit(ctx, *customer.RateLimitID) + if customerRateLimit == nil { + return DecisionAllow, nil } + key := fmt.Sprintf("Customer:%s", customerID) + entityWiseRateLimits := EntityWiseRateLimits{key: {customerRateLimit}} + return gs.CheckRateLimit(ctx, entityWiseRateLimits, tokensBaselines, requestsBaselines) +} - var violations []string +// CheckUserBudget checks if user's budget allows the request (enterprise-only) +// Community build: silent no-op so user-governance absence never silently denies requests. 
+func (gs *LocalGovernanceStore) CheckUserBudget(ctx context.Context, userID string, request *EvaluationRequest, baselines map[string]float64) (Decision, error) { + return DecisionAllow, nil +} - // Check token limit expiry - tokenLimitExpired := false - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if time.Since(rateLimit.TokenLastReset) >= duration { - tokenLimitExpired = true - } - } +// CheckModelRateLimit checks model-level rate limits and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckModelRateLimit(ctx context.Context, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { + // This is to prevent nil pointer dereference + if tokensBaselines == nil { + tokensBaselines = map[string]int64{} } - - // Check request limit expiry - requestLimitExpired := false - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if time.Since(rateLimit.RequestLastReset) >= duration { - requestLimitExpired = true - } - } + if requestsBaselines == nil { + requestsBaselines = map[string]int64{} } - - tokensBaseline := tokensBaselines[rateLimit.ID] - requestsBaseline := requestsBaselines[rateLimit.ID] - - // Check token limit - if !tokenLimitExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { - duration := "unknown" - if rateLimit.TokenResetDuration != nil { - duration = *rateLimit.TokenResetDuration + // Extract model and provider from request + var model string + var provider *schemas.ModelProvider + if request != nil { + model = request.Model + if request.Provider != "" { + provider = &request.Provider } - violations = append(violations, fmt.Sprintf("user token limit exceeded (%d/%d, resets every %s)", - 
rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) } - - // Check request limit - if !requestLimitExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { - duration := "unknown" - if rateLimit.RequestResetDuration != nil { - duration = *rateLimit.RequestResetDuration + // Collect model configs to check: model+provider (if exists) AND model-only (if exists) + entityWiseRateLimits := make(EntityWiseRateLimits) + // Check model+provider config first (more specific) - if provider is provided + if provider != nil { + key := fmt.Sprintf("%s:%s", model, string(*provider)) + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.RateLimitID != nil { + rateLimit := gs.LoadRateLimit(ctx, *mc.RateLimitID) + if rateLimit != nil { + entityWiseRateLimits[fmt.Sprintf("Model:%s:Provider:%s", model, string(*provider))] = []*configstoreTables.TableRateLimit{rateLimit} + } + } } - violations = append(violations, fmt.Sprintf("user request limit exceeded (%d/%d, resets every %s)", - rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) } - - if len(violations) > 0 { - decision := DecisionRateLimited - if len(violations) == 1 { - if strings.Contains(violations[0], "token") { - decision = DecisionTokenLimited - } else if strings.Contains(violations[0], "request") { - decision = DecisionRequestLimited - } + // Always check model-only config (if exists) - regardless of whether model+provider config exists + // Uses findModelOnlyConfig for cross-provider model name normalization + if mc, configKey := gs.findModelOnlyConfig(ctx, model); mc != nil && mc.RateLimitID != nil { + rateLimit := gs.LoadRateLimit(ctx, *mc.RateLimitID) + if rateLimit != nil { + entityWiseRateLimits[fmt.Sprintf("Model:%s", configKey)] = []*configstoreTables.TableRateLimit{rateLimit} } - 
return fmt.Errorf("user rate limit violated: %s", strings.Join(violations, ", ")), decision } - - return nil, DecisionAllow + return gs.CheckRateLimit(ctx, entityWiseRateLimits, tokensBaselines, requestsBaselines) } -// CheckRateLimit checks a single rate limit and returns evaluation result if violated (true if violated, false if not) -func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { - var violations []string +// CheckUserRateLimit checks if user's rate limit allows the request (enterprise-only) +// Community build: silent no-op so user-governance absence never silently denies requests. +func (gs *LocalGovernanceStore) CheckUserRateLimit(ctx context.Context, userID string, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { + return DecisionAllow, nil +} +// CheckVirtualKeyRateLimit checks a virtual key rate limit and returns evaluation result if violated (true if violated, false if not) +func (gs *LocalGovernanceStore) CheckVirtualKeyRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { // Extract provider from request var provider schemas.ModelProvider if request != nil { provider = request.Provider } - // Collect rate limits and their names from the hierarchy - rateLimits, rateLimitNames := gs.collectRateLimitsFromHierarchy(vk, provider) - + entityWiseRateLimits := gs.collectRateLimitsFromHierarchy(ctx, vk, provider) // This is to prevent nil pointer dereference if tokensBaselines == nil { tokensBaselines = map[string]int64{} @@ -1141,81 +972,7 @@ func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, vk *configst if requestsBaselines == nil { requestsBaselines = 
map[string]int64{} } - - for i, rateLimit := range rateLimits { - // Determine token and request expiration independently - tokenExpired := false - requestExpired := false - - // Check if token reset duration is expired - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if time.Since(rateLimit.TokenLastReset) >= duration { - // Token rate limit expired but hasn't been reset yet - skip token checks - // Note: actual reset will happen in post-hook via AtomicRateLimitUpdate - tokenExpired = true - } - } - } - - // Check if request reset duration is expired - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if time.Since(rateLimit.RequestLastReset) >= duration { - // Request rate limit expired but hasn't been reset yet - skip request checks - // Note: actual reset will happen in post-hook via AtomicRateLimitUpdate - requestExpired = true - } - } - } - - tokensBaseline, exists := tokensBaselines[rateLimit.ID] - if !exists { - tokensBaseline = 0 - } - requestsBaseline, exists := requestsBaselines[rateLimit.ID] - if !exists { - requestsBaseline = 0 - } - - // Token limits - check if total usage (local + remote baseline) exceeds limit - // Only check if token limit is not expired - if !tokenExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { - duration := "unknown" - if rateLimit.TokenResetDuration != nil { - duration = *rateLimit.TokenResetDuration - } - violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", - rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) - } - - // Request limits - check if total usage (local + remote baseline) exceeds limit - // Only check if request limit is not expired - if !requestExpired && rateLimit.RequestMaxLimit != nil && 
rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { - duration := "unknown" - if rateLimit.RequestResetDuration != nil { - duration = *rateLimit.RequestResetDuration - } - violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", - rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) - } - - if len(violations) > 0 { - // Determine specific violation type - decision := DecisionRateLimited // Default to general rate limited decision - if len(violations) == 1 { - if strings.Contains(violations[0], "token") { - decision = DecisionTokenLimited // More specific violation type - } else if strings.Contains(violations[0], "request") { - decision = DecisionRequestLimited // More specific violation type - } - } - msg := strings.Join(violations, "; ") - return decision, fmt.Errorf("rate limit violated for %s: %s", rateLimitNames[i], msg) - } - } - - return DecisionAllow, nil // No rate limit violations + return gs.CheckRateLimit(ctx, entityWiseRateLimits, tokensBaselines, requestsBaselines) } // UpdateVirtualKeyBudgetUsageInMemory performs atomic budget updates across the hierarchy (both in memory and in database) @@ -1308,7 +1065,7 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelBudgetUsageInMemory(ctx co // Always check model-only config (if exists) - regardless of whether model+provider config exists // Uses findModelOnlyConfig for cross-provider model name normalization - if mc, _ := gs.findModelOnlyConfig(model); mc != nil && mc.BudgetID != nil { + if mc, _ := gs.findModelOnlyConfig(ctx, model); mc != nil && mc.BudgetID != nil { updateBudget(*mc.BudgetID) } @@ -1316,42 +1073,8 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelBudgetUsageInMemory(ctx co } // UpdateUserBudgetUsageInMemory updates user's budget usage in memory (enterprise-only) +// Community build: silent no-op to avoid per-request error spam when a userID is set. 
func (gs *LocalGovernanceStore) UpdateUserBudgetUsageInMemory(ctx context.Context, userID string, cost float64) error { - if userID == "" || cost <= 0 { - return nil - } - - ug, exists := gs.GetUserGovernance(userID) - if !exists || ug == nil || ug.BudgetID == nil { - return nil - } - - budgetValue, exists := gs.budgets.Load(*ug.BudgetID) - if !exists || budgetValue == nil { - return nil - } - - budget, ok := budgetValue.(*configstoreTables.TableBudget) - if !ok || budget == nil { - return nil - } - - // Clone FIRST to avoid race conditions - now := time.Now() - clone := *budget - // Check if budget needs reset (in-memory check) - operate on clone - if clone.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { - if now.Sub(clone.LastReset) >= duration { - clone.CurrentUsage = 0 - clone.LastReset = now - } - } - } - // Update the clone - clone.CurrentUsage += cost - gs.budgets.Store(clone.ID, &clone) - return nil } @@ -1417,7 +1140,7 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelRateLimitUsageInMemory(ctx // Always check model-only config (if exists) - regardless of whether model+provider config exists // Uses findModelOnlyConfig for cross-provider model name normalization - if mc, _ := gs.findModelOnlyConfig(model); mc != nil && mc.RateLimitID != nil { + if mc, _ := gs.findModelOnlyConfig(ctx, model); mc != nil && mc.RateLimitID != nil { updateRateLimit(*mc.RateLimitID) } @@ -1429,11 +1152,9 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyRateLimitUsageInMemory(ctx conte if vk == nil { return fmt.Errorf("virtual key cannot be nil") } - // Collect rate limit IDs using fast in-memory lookup instead of DB queries - rateLimitIDs := gs.collectRateLimitIDsFromMemory(vk, provider) + rateLimitIDs := gs.collectRateLimitIDsFromMemory(ctx, vk, provider) now := time.Now() - for _, rateLimitID := range rateLimitIDs { // Update in-memory cache for next read (lock-free) if cachedRateLimitValue, exists := 
gs.rateLimits.Load(rateLimitID); exists && cachedRateLimitValue != nil { @@ -1460,7 +1181,6 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyRateLimitUsageInMemory(ctx conte } } } - // Update the clone if shouldUpdateTokens { clone.TokenCurrentUsage += tokensUsed @@ -1476,55 +1196,8 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyRateLimitUsageInMemory(ctx conte } // UpdateUserRateLimitUsageInMemory updates user's rate limit usage in memory (enterprise-only) +// Community build: silent no-op to avoid per-request error spam when a userID is set. func (gs *LocalGovernanceStore) UpdateUserRateLimitUsageInMemory(ctx context.Context, userID string, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { - if userID == "" { - return nil - } - - ug, exists := gs.GetUserGovernance(userID) - if !exists || ug == nil || ug.RateLimitID == nil { - return nil - } - - rateLimitValue, exists := gs.rateLimits.Load(*ug.RateLimitID) - if !exists || rateLimitValue == nil { - return nil - } - - rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit) - if !ok || rateLimit == nil { - return nil - } - - // Clone FIRST to avoid race conditions - now := time.Now() - clone := *rateLimit - // Check if rate limit needs reset (in-memory check) - operate on clone - if clone.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { - if now.Sub(clone.TokenLastReset) >= duration { - clone.TokenCurrentUsage = 0 - clone.TokenLastReset = now - } - } - } - if clone.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { - if now.Sub(clone.RequestLastReset) >= duration { - clone.RequestCurrentUsage = 0 - clone.RequestLastReset = now - } - } - } - // Update the clone - if shouldUpdateTokens { - clone.TokenCurrentUsage += tokensUsed - } - if shouldUpdateRequests { - clone.RequestCurrentUsage++ - } - gs.rateLimits.Store(clone.ID, 
&clone) - return nil } @@ -1532,45 +1205,18 @@ func (gs *LocalGovernanceStore) UpdateUserRateLimitUsageInMemory(ctx context.Con func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget { now := time.Now() var resetBudgets []*configstoreTables.TableBudget - - gs.budgets.Range(func(key, value interface{}) bool { + // We reset all budgets + gs.budgets.Range(func(key, value any) bool { // Type-safe conversion budget, ok := value.(*configstoreTables.TableBudget) if !ok || budget == nil { return true // continue } - // Determine whether the budget needs resetting var shouldReset bool var newLastReset time.Time - - // Check if the owning VK has calendar alignment enabled - // virtualKeys map is keyed by VK value (not ID), so we scan to find by VirtualKeyID - calendarAligned := false - if budget.VirtualKeyID != nil { - gs.virtualKeys.Range(func(_, v interface{}) bool { - if vk, ok := v.(*configstoreTables.TableVirtualKey); ok && vk != nil && vk.ID == *budget.VirtualKeyID { - calendarAligned = vk.CalendarAligned - return false // stop - } - return true - }) - } else if budget.ProviderConfigID != nil { - // Provider config budgets: look up the VK that owns this provider config - gs.virtualKeys.Range(func(_, v interface{}) bool { - if vk, ok := v.(*configstoreTables.TableVirtualKey); ok && vk != nil { - for _, pc := range vk.ProviderConfigs { - if pc.ID == *budget.ProviderConfigID { - calendarAligned = vk.CalendarAligned - return false // stop - } - } - } - return true - }) - } - - if calendarAligned { + // Any budget and rate limit can be calendar aligned + if budget.CalendarAligned { // Calendar-aligned: reset when we've entered a genuinely new calendar period. 
currentPeriodStart := configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, now) if currentPeriodStart.After(budget.LastReset) { @@ -1589,7 +1235,6 @@ func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) newLastReset = now } } - if shouldReset { // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) copiedBudget := *budget @@ -1599,20 +1244,16 @@ func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) gs.LastDBUsagesBudgetsMu.Lock() gs.LastDBUsagesBudgets[copiedBudget.ID] = 0 gs.LastDBUsagesBudgetsMu.Unlock() - // Atomically replace the entry using the original key gs.budgets.Store(key, &copiedBudget) resetBudgets = append(resetBudgets, &copiedBudget) - // Update all VKs, teams, customers, and provider configs that reference this budget - gs.updateBudgetReferences(&copiedBudget) - + gs.updateBudgetReferences(ctx, &copiedBudget) gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", copiedBudget.ID, oldUsage)) } return true // continue }) - return resetBudgets } @@ -1620,71 +1261,92 @@ func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) func (gs *LocalGovernanceStore) ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit { now := time.Now() var resetRateLimits []*configstoreTables.TableRateLimit - - gs.rateLimits.Range(func(key, value interface{}) bool { + gs.rateLimits.Range(func(key, value any) bool { // Type-safe conversion rateLimit, ok := value.(*configstoreTables.TableRateLimit) if !ok || rateLimit == nil { return true // continue } - - needsReset := false - // Check if token reset is needed - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + tokenNeedsReset := false + requestNeedsReset := false + // Any budget and rate limit can be calendar aligned + if rateLimit.CalendarAligned { + // 
Calendar-aligned: reset when we've entered a genuinely new calendar period. + if rateLimit.TokenResetDuration != nil { + currentPeriodStart := configstoreTables.GetCalendarPeriodStart(*rateLimit.TokenResetDuration, now) + if currentPeriodStart.After(rateLimit.TokenLastReset) { + tokenNeedsReset = true + } + } + if rateLimit.RequestResetDuration != nil { + currentPeriodStart := configstoreTables.GetCalendarPeriodStart(*rateLimit.RequestResetDuration, now) + if currentPeriodStart.After(rateLimit.RequestLastReset) { + requestNeedsReset = true + } + } + } else { + // Rolling duration: reset after the configured duration has elapsed + if rateLimit.TokenResetDuration != nil { + duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration) + if err != nil { + gs.logger.Error("invalid budget reset duration %s: %v", *rateLimit.TokenResetDuration, err) + return true // continue + } if now.Sub(rateLimit.TokenLastReset) >= duration { - needsReset = true + tokenNeedsReset = true } } - } - // Check if request reset is needed - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if rateLimit.RequestResetDuration != nil { + duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration) + if err != nil { + gs.logger.Error("invalid budget reset duration %s: %v", *rateLimit.RequestResetDuration, err) + return true // continue + } if now.Sub(rateLimit.RequestLastReset) >= duration { - needsReset = true + requestNeedsReset = true } } } - - if needsReset { - // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) - copiedRateLimit := *rateLimit - - // Reset token limits if expired - if copiedRateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.TokenResetDuration); err == nil { - if now.Sub(copiedRateLimit.TokenLastReset) >= duration { - 
copiedRateLimit.TokenCurrentUsage = 0 - copiedRateLimit.TokenLastReset = now - gs.LastDBUsagesRateLimitsTokensMu.Lock() - gs.LastDBUsagesTokensRateLimits[copiedRateLimit.ID] = 0 - gs.LastDBUsagesRateLimitsTokensMu.Unlock() - } + // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) + copiedRateLimit := *rateLimit + // Reset token limits if expired + if tokenNeedsReset && copiedRateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.TokenResetDuration); err == nil { + if now.Sub(copiedRateLimit.TokenLastReset) >= duration { + copiedRateLimit.TokenCurrentUsage = 0 + copiedRateLimit.TokenLastReset = now + gs.LastDBUsagesRateLimitsTokensMu.Lock() + gs.LastDBUsagesTokensRateLimits[copiedRateLimit.ID] = 0 + gs.LastDBUsagesRateLimitsTokensMu.Unlock() } } - // Reset request limits if expired - if copiedRateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.RequestResetDuration); err == nil { - if now.Sub(copiedRateLimit.RequestLastReset) >= duration { - copiedRateLimit.RequestCurrentUsage = 0 - copiedRateLimit.RequestLastReset = now - gs.LastDBUsagesRateLimitsRequestsMu.Lock() - gs.LastDBUsagesRequestsRateLimits[copiedRateLimit.ID] = 0 - gs.LastDBUsagesRateLimitsRequestsMu.Unlock() - } + } + // Reset request limits if expired + if requestNeedsReset && copiedRateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.RequestResetDuration); err == nil { + if now.Sub(copiedRateLimit.RequestLastReset) >= duration { + copiedRateLimit.RequestCurrentUsage = 0 + copiedRateLimit.RequestLastReset = now + gs.LastDBUsagesRateLimitsRequestsMu.Lock() + gs.LastDBUsagesRequestsRateLimits[copiedRateLimit.ID] = 0 + gs.LastDBUsagesRateLimitsRequestsMu.Unlock() } } - + } + // Only commit the copy + emit for DB reset + rescan references when something + // actually expired. 
Without this guard the 10-second tick would always call + // gs.updateRateLimitReferences (which scans every VK + provider-config) and + // return every rate limit to the caller for a redundant DB update. Mirrors + // the `if shouldReset { ... }` guard in ResetExpiredBudgetsInMemory above. + if tokenNeedsReset || requestNeedsReset { // Atomically replace the entry using the original key gs.rateLimits.Store(key, &copiedRateLimit) resetRateLimits = append(resetRateLimits, &copiedRateLimit) - // Update all VKs and provider configs that reference this rate limit - gs.updateRateLimitReferences(&copiedRateLimit) + gs.updateRateLimitReferences(ctx, &copiedRateLimit) } return true // continue }) - return resetRateLimits } @@ -1763,7 +1425,6 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselin if gs.configStore == nil { return nil } - // This is to prevent nil pointer dereference if tokenBaselines == nil { tokenBaselines = map[string]int64{} @@ -1771,7 +1432,6 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselin if requestBaselines == nil { requestBaselines = map[string]int64{} } - // Collect unique rate limit IDs from virtual keys, model configs, and providers rateLimitIDs := make(map[string]bool) gs.virtualKeys.Range(func(key, value interface{}) bool { @@ -1791,7 +1451,6 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselin } return true // continue }) - // Collect rate limit IDs from model configs gs.modelConfigs.Range(func(key, value interface{}) bool { mc, ok := value.(*configstoreTables.TableModelConfig) @@ -1803,7 +1462,6 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselin } return true // continue }) - // Collect rate limit IDs from providers gs.providers.Range(func(key, value interface{}) bool { provider, ok := value.(*configstoreTables.TableProvider) @@ -1840,18 +1498,6 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, 
tokenBaselin return true // continue }) - // Collect rate limit IDs from users (enterprise) - gs.users.Range(func(key, value interface{}) bool { - user, ok := value.(*UserGovernance) - if !ok || user == nil { - return true // continue - } - if user.RateLimitID != nil { - rateLimitIDs[*user.RateLimitID] = true - } - return true // continue - }) - // Prepare rate limit usage updates with baselines type rateLimitUpdate struct { ID string @@ -1922,14 +1568,11 @@ func (gs *LocalGovernanceStore) DumpBudgets(ctx context.Context, baselines map[s if gs.configStore == nil { return nil } - // This is to prevent nil pointer dereference if baselines == nil { baselines = map[string]float64{} } - budgets := make(map[string]*configstoreTables.TableBudget) - gs.budgets.Range(func(key, value interface{}) bool { // Type-safe conversion keyStr, keyOk := key.(string) @@ -1940,7 +1583,6 @@ func (gs *LocalGovernanceStore) DumpBudgets(ctx context.Context, baselines map[s } return true // continue iteration }) - if len(budgets) > 0 && gs.configStore != nil { if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { // Update each budget atomically using direct UPDATE to avoid deadlocks @@ -1981,7 +1623,6 @@ func (gs *LocalGovernanceStore) DumpBudgets(ctx context.Context, baselines map[s return fmt.Errorf("failed to dump budgets to database: %w", err) } } - return nil } @@ -2181,7 +1822,6 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c gs.virtualKeys = sync.Map{} gs.teams = sync.Map{} gs.customers = sync.Map{} - gs.users = sync.Map{} gs.budgets = sync.Map{} gs.rateLimits = sync.Map{} gs.modelConfigs = sync.Map{} @@ -2274,7 +1914,7 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c gs.routingRules.Range(func(key, value interface{}) bool { if rules, ok := value.([]*configstoreTables.TableRoutingRule); ok { for _, rule := range rules { - if _, err := gs.GetRoutingProgram(rule); err != nil { + if _, err := 
gs.GetRoutingProgram(ctx, rule); err != nil { gs.logger.Warn("Failed to pre-compile routing program for rule %s: %v", rule.Name, err) } } @@ -2304,23 +1944,24 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c gs.LastDBUsagesRateLimitsRequestsMu.Unlock() } -// UTILITY FUNCTIONS - // collectRateLimitsFromHierarchy collects rate limits and their metadata from the hierarchy (Provider Configs → VK → Team → Customer) -func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) ([]*configstoreTables.TableRateLimit, []string) { +func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) map[string][]*configstoreTables.TableRateLimit { if vk == nil { - return nil, nil + return nil } - var rateLimits []*configstoreTables.TableRateLimit - var rateLimitNames []string + rateLimitsWithCategories := map[string][]*configstoreTables.TableRateLimit{} + seen := map[string]bool{} for _, pc := range vk.ProviderConfigs { if pc.RateLimitID != nil && pc.Provider == string(requestedProvider) { if rateLimitValue, exists := gs.rateLimits.Load(*pc.RateLimitID); exists && rateLimitValue != nil { if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { - rateLimits = append(rateLimits, rateLimit) - rateLimitNames = append(rateLimitNames, pc.Provider) + if categoryRateLimits := rateLimitsWithCategories[pc.Provider]; categoryRateLimits == nil { + rateLimitsWithCategories[pc.Provider] = []*configstoreTables.TableRateLimit{} + } + rateLimitsWithCategories[pc.Provider] = append(rateLimitsWithCategories[pc.Provider], rateLimit) + seen[rateLimit.ID] = true } } } @@ -2329,8 +1970,11 @@ func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTa if vk.RateLimitID != nil { if rateLimitValue, exists := 
gs.rateLimits.Load(*vk.RateLimitID); exists && rateLimitValue != nil { if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { - rateLimits = append(rateLimits, rateLimit) - rateLimitNames = append(rateLimitNames, "VK") + if categoryRateLimits := rateLimitsWithCategories["VK"]; categoryRateLimits == nil { + rateLimitsWithCategories["VK"] = []*configstoreTables.TableRateLimit{} + } + rateLimitsWithCategories["VK"] = append(rateLimitsWithCategories["VK"], rateLimit) + seen[rateLimit.ID] = true } } } @@ -2343,8 +1987,11 @@ func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTa if team.RateLimitID != nil { if rateLimitValue, exists := gs.rateLimits.Load(*team.RateLimitID); exists && rateLimitValue != nil { if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { - rateLimits = append(rateLimits, rateLimit) - rateLimitNames = append(rateLimitNames, "Team") + if categoryRateLimits := rateLimitsWithCategories["Team"]; categoryRateLimits == nil { + rateLimitsWithCategories["Team"] = []*configstoreTables.TableRateLimit{} + } + rateLimitsWithCategories["Team"] = append(rateLimitsWithCategories["Team"], rateLimit) + seen[rateLimit.ID] = true } } } @@ -2357,8 +2004,11 @@ func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTa if customer.RateLimitID != nil { if rateLimitValue, exists := gs.rateLimits.Load(*customer.RateLimitID); exists && rateLimitValue != nil { if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { - rateLimits = append(rateLimits, rateLimit) - rateLimitNames = append(rateLimitNames, "Customer") + if categoryRateLimits := rateLimitsWithCategories["Customer"]; categoryRateLimits == nil { + rateLimitsWithCategories["Customer"] = []*configstoreTables.TableRateLimit{} + } + rateLimitsWithCategories["Customer"] = append(rateLimitsWithCategories["Customer"], rateLimit) + seen[rateLimit.ID] = 
true } } } @@ -2376,27 +2026,26 @@ func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTa if customer.RateLimitID != nil { if rateLimitValue, exists := gs.rateLimits.Load(*customer.RateLimitID); exists && rateLimitValue != nil { if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { - rateLimits = append(rateLimits, rateLimit) - rateLimitNames = append(rateLimitNames, "Customer") + if categoryRateLimits := rateLimitsWithCategories["Customer"]; categoryRateLimits == nil { + rateLimitsWithCategories["Customer"] = []*configstoreTables.TableRateLimit{} + } + rateLimitsWithCategories["Customer"] = append(rateLimitsWithCategories["Customer"], rateLimit) + seen[rateLimit.ID] = true } } } } } } - - return rateLimits, rateLimitNames + return rateLimitsWithCategories } -// collectBudgetsFromHierarchy collects budgets and their metadata from the hierarchy (Provider Configs → VK → Team → Customer) -func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) ([]*configstoreTables.TableBudget, []string) { +// collectBudgetsFromHierarchy collects budgets and their metadata from the hierarchy (Provider Configs → VK → Customer -> User -> Team → BusinessUnit) +func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(_ context.Context, vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) EntityWiseBudgets { if vk == nil { - return nil, nil + return nil } - - var budgets []*configstoreTables.TableBudget - var budgetNames []string - + entityWiseBudgets := make(EntityWiseBudgets) // Collect all budgets in hierarchy order using lock-free sync.Map access (Provider Configs → VK → Team → Customer) seen := make(map[string]bool) for _, pc := range vk.ProviderConfigs { @@ -2410,14 +2059,15 @@ func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTable } if budgetValue, exists := 
gs.budgets.Load(b.ID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { - budgets = append(budgets, budget) - budgetNames = append(budgetNames, pc.Provider) + if categoryBudgets := entityWiseBudgets[pc.Provider]; categoryBudgets == nil { + entityWiseBudgets[pc.Provider] = []*configstoreTables.TableBudget{} + } + entityWiseBudgets[pc.Provider] = append(entityWiseBudgets[pc.Provider], budget) seen[budget.ID] = true } } } } - // VK-level multi-budgets for _, b := range vk.Budgets { if seen[b.ID] { @@ -2425,13 +2075,14 @@ func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTable } if budgetValue, exists := gs.budgets.Load(b.ID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { - budgets = append(budgets, budget) - budgetNames = append(budgetNames, "VK") + if categoryBudgets := entityWiseBudgets["VK"]; categoryBudgets == nil { + entityWiseBudgets["VK"] = []*configstoreTables.TableBudget{} + } + entityWiseBudgets["VK"] = append(entityWiseBudgets["VK"], budget) seen[budget.ID] = true } } } - var teamCustomerID string if vk.TeamID != nil { if teamValue, exists := gs.teams.Load(*vk.TeamID); exists && teamValue != nil { @@ -2439,8 +2090,11 @@ func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTable if team.BudgetID != nil { if budgetValue, exists := gs.budgets.Load(*team.BudgetID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { - budgets = append(budgets, budget) - budgetNames = append(budgetNames, "Team") + if categoryBudgets := entityWiseBudgets["Team"]; categoryBudgets == nil { + entityWiseBudgets["Team"] = []*configstoreTables.TableBudget{} + } + entityWiseBudgets["Team"] = append(entityWiseBudgets["Team"], budget) + seen[budget.ID] = true } } } @@ -2453,8 +2107,11 @@ func (gs *LocalGovernanceStore) 
collectBudgetsFromHierarchy(vk *configstoreTable if customer.BudgetID != nil { if budgetValue, exists := gs.budgets.Load(*customer.BudgetID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { - budgets = append(budgets, budget) - budgetNames = append(budgetNames, "Customer") + if categoryBudgets := entityWiseBudgets["Customer"]; categoryBudgets == nil { + entityWiseBudgets["Customer"] = []*configstoreTables.TableBudget{} + } + entityWiseBudgets["Customer"] = append(entityWiseBudgets["Customer"], budget) + seen[budget.ID] = true } } } @@ -2464,7 +2121,6 @@ func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTable } } } - // Check Customer budget if VK directly belongs to a customer (skip if already collected via team) if vk.CustomerID != nil && (teamCustomerID == "" || *vk.CustomerID != teamCustomerID) { if customerValue, exists := gs.customers.Load(*vk.CustomerID); exists && customerValue != nil { @@ -2472,46 +2128,48 @@ func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTable if customer.BudgetID != nil { if budgetValue, exists := gs.budgets.Load(*customer.BudgetID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { - budgets = append(budgets, budget) - budgetNames = append(budgetNames, "Customer") + if categoryBudgets := entityWiseBudgets["Customer"]; categoryBudgets == nil { + entityWiseBudgets["Customer"] = []*configstoreTables.TableBudget{} + } + entityWiseBudgets["Customer"] = append(entityWiseBudgets["Customer"], budget) + seen[budget.ID] = true } } } } } } - - return budgets, budgetNames + return entityWiseBudgets } // collectBudgetIDsFromMemory collects budget IDs from in-memory store data (lock-free) func (gs *LocalGovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { - budgets, _ := 
gs.collectBudgetsFromHierarchy(vk, provider) - - budgetIDs := make([]string, len(budgets)) - for i, budget := range budgets { - budgetIDs[i] = budget.ID + budgetsWithCategory := gs.collectBudgetsFromHierarchy(ctx, vk, provider) + budgetIDs := []string{} + for _, budgets := range budgetsWithCategory { + for _, budget := range budgets { + budgetIDs = append(budgetIDs, budget.ID) + } } - return budgetIDs } // collectRateLimitIDsFromMemory collects rate limit IDs from in-memory store data (lock-free) -func (gs *LocalGovernanceStore) collectRateLimitIDsFromMemory(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { - rateLimits, _ := gs.collectRateLimitsFromHierarchy(vk, provider) - - rateLimitIDs := make([]string, len(rateLimits)) - for i, rateLimit := range rateLimits { - rateLimitIDs[i] = rateLimit.ID +func (gs *LocalGovernanceStore) collectRateLimitIDsFromMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { + rateLimitsWithCategories := gs.collectRateLimitsFromHierarchy(ctx, vk, provider) + rateLimitIDs := []string{} + for _, rateLimits := range rateLimitsWithCategories { + for _, rateLimit := range rateLimits { + rateLimitIDs = append(rateLimitIDs, rateLimit.ID) + } } - return rateLimitIDs } // PUBLIC API METHODS // CreateVirtualKeyInMemory adds a new virtual key to the in-memory store (lock-free) -func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { +func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey) { if vk == nil { return // Nothing to create } @@ -2542,7 +2200,7 @@ func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.T } // UpdateVirtualKeyInMemory updates an existing virtual key in the in-memory store (lock-free) -func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines 
map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) { +func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) { if vk == nil { return // Nothing to update } @@ -2660,12 +2318,12 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.T } gs.virtualKeys.Store(vk.Value, &clone) } else { - gs.CreateVirtualKeyInMemory(vk) + gs.CreateVirtualKeyInMemory(ctx, vk) } } // DeleteVirtualKeyInMemory removes a virtual key from the in-memory store -func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(vkID string) { +func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(ctx context.Context, vkID string) { if vkID == "" { return // Nothing to delete } @@ -2709,7 +2367,7 @@ func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(vkID string) { } // CreateTeamInMemory adds a new team to the in-memory store (lock-free) -func (gs *LocalGovernanceStore) CreateTeamInMemory(team *configstoreTables.TableTeam) { +func (gs *LocalGovernanceStore) CreateTeamInMemory(ctx context.Context, team *configstoreTables.TableTeam) { if team == nil { return // Nothing to create } @@ -2728,7 +2386,7 @@ func (gs *LocalGovernanceStore) CreateTeamInMemory(team *configstoreTables.Table } // UpdateTeamInMemory updates an existing team in the in-memory store (lock-free) -func (gs *LocalGovernanceStore) UpdateTeamInMemory(team *configstoreTables.TableTeam, budgetBaselines map[string]float64) { +func (gs *LocalGovernanceStore) UpdateTeamInMemory(ctx context.Context, team *configstoreTables.TableTeam, budgetBaselines map[string]float64) { if team == nil { return // Nothing to update } @@ -2779,12 +2437,12 @@ func (gs *LocalGovernanceStore) UpdateTeamInMemory(team *configstoreTables.Table gs.teams.Store(team.ID, &clone) } else { - 
gs.CreateTeamInMemory(team) + gs.CreateTeamInMemory(ctx, team) } } // DeleteTeamInMemory removes a team from the in-memory store (lock-free) -func (gs *LocalGovernanceStore) DeleteTeamInMemory(teamID string) { +func (gs *LocalGovernanceStore) DeleteTeamInMemory(ctx context.Context, teamID string) { if teamID == "" { return // Nothing to delete } @@ -2823,26 +2481,23 @@ func (gs *LocalGovernanceStore) DeleteTeamInMemory(teamID string) { } // CreateCustomerInMemory adds a new customer to the in-memory store (lock-free) -func (gs *LocalGovernanceStore) CreateCustomerInMemory(customer *configstoreTables.TableCustomer) { +func (gs *LocalGovernanceStore) CreateCustomerInMemory(ctx context.Context, customer *configstoreTables.TableCustomer) { if customer == nil { return // Nothing to create } - // Create associated budget if exists if customer.Budget != nil { gs.budgets.Store(customer.Budget.ID, customer.Budget) } - // Create associated rate limit if exists if customer.RateLimit != nil { gs.rateLimits.Store(customer.RateLimit.ID, customer.RateLimit) } - gs.customers.Store(customer.ID, customer) } // UpdateCustomerInMemory updates an existing customer in the in-memory store (lock-free) -func (gs *LocalGovernanceStore) UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) { +func (gs *LocalGovernanceStore) UpdateCustomerInMemory(ctx context.Context, customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) { if customer == nil { return // Nothing to update } @@ -2891,16 +2546,15 @@ func (gs *LocalGovernanceStore) UpdateCustomerInMemory(customer *configstoreTabl gs.customers.Store(customer.ID, &clone) } else { - gs.CreateCustomerInMemory(customer) + gs.CreateCustomerInMemory(ctx, customer) } } // DeleteCustomerInMemory removes a customer from the in-memory store (lock-free) -func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { +func (gs *LocalGovernanceStore) 
DeleteCustomerInMemory(ctx context.Context, customerID string) { if customerID == "" { return // Nothing to delete } - // Get customer to check for associated budget and rate limit if customerValue, exists := gs.customers.Load(customerID); exists && customerValue != nil { if customer, ok := customerValue.(*configstoreTables.TableCustomer); ok && customer != nil { @@ -2914,7 +2568,6 @@ func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { } } } - // Set customer_id to null for all virtual keys associated with the customer // Iterate through all VKs since customer.VirtualKeys may not be populated gs.virtualKeys.Range(func(key, value interface{}) bool { @@ -2930,7 +2583,6 @@ func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { } return true // continue iteration }) - // Set customer_id to null for all teams associated with the customer // Iterate through all teams since customer.Teams may not be populated gs.teams.Range(func(key, value interface{}) bool { @@ -2946,128 +2598,37 @@ func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { } return true // continue iteration }) - gs.customers.Delete(customerID) } // GetUserGovernance retrieves user governance data by user ID (enterprise-only, lock-free) -func (gs *LocalGovernanceStore) GetUserGovernance(userID string) (*UserGovernance, bool) { - value, exists := gs.users.Load(userID) - if !exists || value == nil { - return nil, false - } - ug, ok := value.(*UserGovernance) - if !ok || ug == nil { - return nil, false - } - return ug, true +func (gs *LocalGovernanceStore) GetUserGovernance(ctx context.Context, userID string) (*UserGovernance, bool) { + // User governance is part of enterprise + return nil, false } // CreateUserGovernanceInMemory adds user governance data to the in-memory store (enterprise-only) -func (gs *LocalGovernanceStore) CreateUserGovernanceInMemory(userID string, budget *configstoreTables.TableBudget, rateLimit 
*configstoreTables.TableRateLimit) { - if userID == "" { - return - } - - ug := &UserGovernance{ - UserID: userID, - } - - if budget != nil { - ug.BudgetID = &budget.ID - ug.Budget = budget - gs.budgets.Store(budget.ID, budget) - } - - if rateLimit != nil { - ug.RateLimitID = &rateLimit.ID - ug.RateLimit = rateLimit - gs.rateLimits.Store(rateLimit.ID, rateLimit) - } - - gs.users.Store(userID, ug) +func (gs *LocalGovernanceStore) CreateUserGovernanceInMemory(ctx context.Context, userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) { + // NoOp + // Available in enterprise } // UpdateUserGovernanceInMemory updates user governance data in the in-memory store (enterprise-only) -func (gs *LocalGovernanceStore) UpdateUserGovernanceInMemory(userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) { - if userID == "" { - return - } - - existingValue, exists := gs.users.Load(userID) - var existingUG *UserGovernance - if exists && existingValue != nil { - existingUG, _ = existingValue.(*UserGovernance) - } - - ug := &UserGovernance{ - UserID: userID, - } - - // Handle budget updates - if budget != nil { - ug.BudgetID = &budget.ID - // Preserve existing usage from memory when updating budget config - if existingBudgetValue, exists := gs.budgets.Load(budget.ID); exists && existingBudgetValue != nil { - if existingBudget, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && existingBudget != nil { - budget.CurrentUsage = existingBudget.CurrentUsage - budget.LastReset = existingBudget.LastReset - } - } - ug.Budget = budget - gs.budgets.Store(budget.ID, budget) - } else if existingUG != nil && existingUG.BudgetID != nil { - // Budget was removed, delete from memory - gs.budgets.Delete(*existingUG.BudgetID) - } - - // Handle rate limit updates - if rateLimit != nil { - ug.RateLimitID = &rateLimit.ID - // Preserve existing usage from memory when updating rate limit config - if 
existingRateLimitValue, exists := gs.rateLimits.Load(rateLimit.ID); exists && existingRateLimitValue != nil { - if existingRateLimit, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && existingRateLimit != nil { - rateLimit.TokenCurrentUsage = existingRateLimit.TokenCurrentUsage - rateLimit.TokenLastReset = existingRateLimit.TokenLastReset - rateLimit.RequestCurrentUsage = existingRateLimit.RequestCurrentUsage - rateLimit.RequestLastReset = existingRateLimit.RequestLastReset - } - } - ug.RateLimit = rateLimit - gs.rateLimits.Store(rateLimit.ID, rateLimit) - } else if existingUG != nil && existingUG.RateLimitID != nil { - // Rate limit was removed, delete from memory - gs.rateLimits.Delete(*existingUG.RateLimitID) - } - - gs.users.Store(userID, ug) +func (gs *LocalGovernanceStore) UpdateUserGovernanceInMemory(ctx context.Context, userID string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) { + // NoOp + // Available in enterprise } // DeleteUserGovernanceInMemory removes user governance data from the in-memory store (enterprise-only) -func (gs *LocalGovernanceStore) DeleteUserGovernanceInMemory(userID string) { - if userID == "" { - return - } - - // Get existing user governance to clean up associated budgets/rate limits - if existingValue, exists := gs.users.Load(userID); exists && existingValue != nil { - if ug, ok := existingValue.(*UserGovernance); ok && ug != nil { - if ug.BudgetID != nil { - gs.budgets.Delete(*ug.BudgetID) - } - if ug.RateLimitID != nil { - gs.rateLimits.Delete(*ug.RateLimitID) - } - } - } - - gs.users.Delete(userID) +func (gs *LocalGovernanceStore) DeleteUserGovernanceInMemory(ctx context.Context, userID string) { + // NoOp + // Available in enterprise } // UpdateModelConfigInMemory adds or updates a model config in the in-memory store (lock-free) // Preserves existing usage values when updating budgets and rate limits // Returns the updated model config with potentially modified 
usage values -func (gs *LocalGovernanceStore) UpdateModelConfigInMemory(mc *configstoreTables.TableModelConfig) *configstoreTables.TableModelConfig { +func (gs *LocalGovernanceStore) UpdateModelConfigInMemory(ctx context.Context, mc *configstoreTables.TableModelConfig) *configstoreTables.TableModelConfig { if mc == nil { return nil // Nothing to update } @@ -3113,7 +2674,7 @@ func (gs *LocalGovernanceStore) UpdateModelConfigInMemory(mc *configstoreTables. } // DeleteModelConfigInMemory removes a model config from the in-memory store (lock-free) -func (gs *LocalGovernanceStore) DeleteModelConfigInMemory(mcID string) { +func (gs *LocalGovernanceStore) DeleteModelConfigInMemory(ctx context.Context, mcID string) { if mcID == "" { return // Nothing to delete } @@ -3146,7 +2707,7 @@ func (gs *LocalGovernanceStore) DeleteModelConfigInMemory(mcID string) { // UpdateProviderInMemory adds or updates a provider in the in-memory store (lock-free) // Preserves existing usage values when updating budgets and rate limits // Returns the updated provider with potentially modified usage values -func (gs *LocalGovernanceStore) UpdateProviderInMemory(provider *configstoreTables.TableProvider) *configstoreTables.TableProvider { +func (gs *LocalGovernanceStore) UpdateProviderInMemory(ctx context.Context, provider *configstoreTables.TableProvider) *configstoreTables.TableProvider { if provider == nil { return nil // Nothing to update } @@ -3182,11 +2743,10 @@ func (gs *LocalGovernanceStore) UpdateProviderInMemory(provider *configstoreTabl } // DeleteProviderInMemory removes a provider from the in-memory store (lock-free) -func (gs *LocalGovernanceStore) DeleteProviderInMemory(providerName string) { +func (gs *LocalGovernanceStore) DeleteProviderInMemory(ctx context.Context, providerName string) { if providerName == "" { return // Nothing to delete } - // Get provider to check for associated budget/rate limit if providerValue, exists := gs.providers.Load(providerName); exists && 
providerValue != nil { if provider, ok := providerValue.(*configstoreTables.TableProvider); ok && provider != nil { @@ -3201,14 +2761,13 @@ func (gs *LocalGovernanceStore) DeleteProviderInMemory(providerName string) { } } } - gs.providers.Delete(providerName) } // Helper functions // updateBudgetReferences updates all VKs, teams, customers, and provider configs that reference a reset budget -func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreTables.TableBudget) { +func (gs *LocalGovernanceStore) updateBudgetReferences(ctx context.Context, resetBudget *configstoreTables.TableBudget) { budgetID := resetBudget.ID // Update VKs that reference this budget gs.virtualKeys.Range(func(key, value interface{}) bool { @@ -3226,7 +2785,6 @@ func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreT needsUpdate = true } } - // Check provider config budgets if vk.ProviderConfigs != nil { for i := range clone.ProviderConfigs { @@ -3238,13 +2796,11 @@ func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreT } } } - if needsUpdate { gs.virtualKeys.Store(key, &clone) } return true // continue }) - // Update teams that reference this budget gs.teams.Range(func(key, value interface{}) bool { team, ok := value.(*configstoreTables.TableTeam) @@ -3258,7 +2814,6 @@ func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreT } return true // continue }) - // Update customers that reference this budget gs.customers.Range(func(key, value interface{}) bool { customer, ok := value.(*configstoreTables.TableCustomer) @@ -3272,24 +2827,10 @@ func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreT } return true // continue }) - - // Update users that reference this budget (enterprise) - gs.users.Range(func(key, value interface{}) bool { - user, ok := value.(*UserGovernance) - if !ok || user == nil { - return true // continue - } - if user.BudgetID != nil && *user.BudgetID == 
budgetID { - clone := *user - clone.Budget = resetBudget - gs.users.Store(key, &clone) - } - return true // continue - }) } // updateRateLimitReferences updates all VKs, teams, customers, users and provider configs that reference a reset rate limit -func (gs *LocalGovernanceStore) updateRateLimitReferences(resetRateLimit *configstoreTables.TableRateLimit) { +func (gs *LocalGovernanceStore) updateRateLimitReferences(ctx context.Context, resetRateLimit *configstoreTables.TableRateLimit) { rateLimitID := resetRateLimit.ID // Update VKs that reference this rate limit gs.virtualKeys.Range(func(key, value interface{}) bool { @@ -3321,7 +2862,6 @@ func (gs *LocalGovernanceStore) updateRateLimitReferences(resetRateLimit *config } return true // continue }) - // Update teams that reference this rate limit gs.teams.Range(func(key, value interface{}) bool { team, ok := value.(*configstoreTables.TableTeam) @@ -3335,7 +2875,6 @@ func (gs *LocalGovernanceStore) updateRateLimitReferences(resetRateLimit *config } return true // continue }) - // Update customers that reference this rate limit gs.customers.Range(func(key, value interface{}) bool { customer, ok := value.(*configstoreTables.TableCustomer) @@ -3349,20 +2888,6 @@ func (gs *LocalGovernanceStore) updateRateLimitReferences(resetRateLimit *config } return true // continue }) - - // Update users that reference this rate limit (enterprise) - gs.users.Range(func(key, value interface{}) bool { - user, ok := value.(*UserGovernance) - if !ok || user == nil { - return true // continue - } - if user.RateLimitID != nil && *user.RateLimitID == rateLimitID { - clone := *user - clone.RateLimit = resetRateLimit - gs.users.Store(key, &clone) - } - return true // continue - }) } // HasRoutingRules checks if there are any routing rules configured @@ -3377,7 +2902,7 @@ func (gs *LocalGovernanceStore) HasRoutingRules(ctx context.Context) bool { } // GetAllRoutingRules gets all routing rules from in-memory cache -func (gs 
*LocalGovernanceStore) GetAllRoutingRules() []*configstoreTables.TableRoutingRule { +func (gs *LocalGovernanceStore) GetAllRoutingRules(ctx context.Context) []*configstoreTables.TableRoutingRule { var result []*configstoreTables.TableRoutingRule // Iterate through all cached rules @@ -3403,7 +2928,7 @@ func (gs *LocalGovernanceStore) GetAllRoutingRules() []*configstoreTables.TableR // GetScopedRoutingRules retrieves routing rules by scope and scope ID (from in-memory cache) // Rules are already sorted by priority ASC (0 is highest priority) -func (gs *LocalGovernanceStore) GetScopedRoutingRules(scope string, scopeID string) []*configstoreTables.TableRoutingRule { +func (gs *LocalGovernanceStore) GetScopedRoutingRules(ctx context.Context, scope string, scopeID string) []*configstoreTables.TableRoutingRule { // Build cache key: "scope:scopeID" (scopeID empty string for global) var key string if scope == "global" { @@ -3437,7 +2962,7 @@ func (gs *LocalGovernanceStore) GetScopedRoutingRules(scope string, scopeID stri // GetRoutingProgram compiles a CEL expression and caches the resulting program // Uses the singleton CEL environment for efficiency // Returns error if compilation fails -func (gs *LocalGovernanceStore) GetRoutingProgram(rule *configstoreTables.TableRoutingRule) (cel.Program, error) { +func (gs *LocalGovernanceStore) GetRoutingProgram(ctx context.Context, rule *configstoreTables.TableRoutingRule) (cel.Program, error) { if rule == nil { return nil, fmt.Errorf("routing rule cannot be nil") } @@ -3558,7 +3083,7 @@ func (gs *LocalGovernanceStore) GetBudgetAndRateLimitStatus(ctx context.Context, // Fall back to model-only config (if exists) // Uses findModelOnlyConfig for cross-provider model name normalization - if modelConfig, _ := gs.findModelOnlyConfig(model); modelConfig != nil { + if modelConfig, _ := gs.findModelOnlyConfig(ctx, model); modelConfig != nil { // Get rate limit status if modelConfig.RateLimitID != nil { if rateLimitValue, ok := 
gs.rateLimits.Load(*modelConfig.RateLimitID); ok && rateLimitValue != nil { @@ -3722,11 +3247,10 @@ func (gs *LocalGovernanceStore) GetBudgetAndRateLimitStatus(ctx context.Context, } // UpdateRoutingRuleInMemory updates a routing rule in the in-memory cache -func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(rule *configstoreTables.TableRoutingRule) error { +func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(ctx context.Context, rule *configstoreTables.TableRoutingRule) error { if rule == nil { return fmt.Errorf("routing rule cannot be nil") } - // First, remove the rule from ALL scopes (in case it was moved from one scope to another) gs.routingRules.Range(func(key, value interface{}) bool { rules, ok := value.([]*configstoreTables.TableRoutingRule) @@ -3752,7 +3276,6 @@ func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(rule *configstoreTable } return true }) - // Build cache key for the new scope var key string if rule.Scope == "global" { @@ -3764,7 +3287,6 @@ func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(rule *configstoreTable } key = fmt.Sprintf("%s:%s", rule.Scope, scopeID) } - // Load existing rules for this scope var rules []*configstoreTables.TableRoutingRule if value, ok := gs.routingRules.Load(key); ok { @@ -3772,38 +3294,31 @@ func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(rule *configstoreTable rules = existing } } - // Add the rule to the new scope rules = append(rules, rule) - // Sort by priority ASC (0 is highest priority, higher numbers are lower priority) sort.Slice(rules, func(i, j int) bool { return rules[i].Priority < rules[j].Priority }) - // Store back in cache gs.routingRules.Store(key, rules) - // Invalidate compiled program cache for this rule (expression may have changed) gs.compiledRoutingPrograms.Delete(rule.ID) - // Recompile the program immediately to update cache with fresh compilation - if _, err := gs.GetRoutingProgram(rule); err != nil { + if _, err := gs.GetRoutingProgram(ctx, rule); err 
!= nil { gs.logger.Warn("Failed to recompile routing program for rule %s: %v", rule.Name, err) } - return nil } // DeleteRoutingRuleInMemory removes a routing rule from the in-memory cache -func (gs *LocalGovernanceStore) DeleteRoutingRuleInMemory(id string) error { +func (gs *LocalGovernanceStore) DeleteRoutingRuleInMemory(ctx context.Context, id string) error { // Loop over all rules and delete the one with the matching id gs.routingRules.Range(func(key, value interface{}) bool { rules, ok := value.([]*configstoreTables.TableRoutingRule) if !ok { return true } - // Find and filter out the rule with matching ID var filteredRules []*configstoreTables.TableRoutingRule for _, r := range rules { @@ -3811,7 +3326,6 @@ func (gs *LocalGovernanceStore) DeleteRoutingRuleInMemory(id string) error { filteredRules = append(filteredRules, r) } } - // Update or delete the key if len(filteredRules) == 0 { gs.routingRules.Delete(key) @@ -3820,9 +3334,7 @@ func (gs *LocalGovernanceStore) DeleteRoutingRuleInMemory(id string) error { } return true }) - // Invalidate compiled program cache for this rule gs.compiledRoutingPrograms.Delete(id) - return nil } diff --git a/plugins/governance/store_test.go b/plugins/governance/store_test.go index 9f328b2047..c1d07204d1 100644 --- a/plugins/governance/store_test.go +++ b/plugins/governance/store_test.go @@ -53,7 +53,7 @@ func TestGovernanceStore_GetVirtualKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - vk, exists := store.GetVirtualKey(tt.vkValue) + vk, exists := store.GetVirtualKey(context.Background(), tt.vkValue) if tt.wantNil { assert.False(t, exists) assert.Nil(t, vk) @@ -85,7 +85,7 @@ func TestGovernanceStore_ConcurrentReads(t *testing.T) { go func() { defer wg.Done() for j := 0; j < 100; j++ { - vk, exists := store.GetVirtualKey("sk-bf-test") + vk, exists := store.GetVirtualKey(context.Background(), "sk-bf-test") if !exists || vk == nil { errorCount.Add(1) return @@ -114,7 +114,7 @@ func 
TestGovernanceStore_CheckBudget_SingleBudget(t *testing.T) { require.NoError(t, err) // Retrieve VK with budget - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") tests := []struct { name string @@ -152,8 +152,8 @@ func TestGovernanceStore_CheckBudget_SingleBudget(t *testing.T) { Budgets: []configstoreTables.TableBudget{*testBudget}, }, nil) - testVK, _ = testStore.GetVirtualKey("sk-bf-test") - err := testStore.CheckBudget(context.Background(), testVK, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + testVK, _ = testStore.GetVirtualKey(context.Background(), "sk-bf-test") + _, err := testStore.CheckVirtualKeyBudget(context.Background(), testVK, &EvaluationRequest{Provider: schemas.OpenAI}, nil) if tt.shouldErr { assert.Error(t, err, "Expected error for usage check") } else { @@ -190,10 +190,10 @@ func TestGovernanceStore_CheckBudget_HierarchyValidation(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") // Test: All budgets under limit should pass - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err, "Should pass when all budgets are under limit") // Test: If VK budget exceeds limit, should fail @@ -207,7 +207,7 @@ func TestGovernanceStore_CheckBudget_HierarchyValidation(t *testing.T) { } } } - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) require.Error(t, err, "Should fail when VK budget exceeds limit") } @@ -232,8 +232,8 @@ func TestGovernanceStore_MultiBudget_AllUnderLimit(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = 
store.GetVirtualKey("sk-bf-test") - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err, "Should pass when all budgets are under limit") } @@ -257,8 +257,8 @@ func TestGovernanceStore_MultiBudget_SmallBudgetExceeded(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = store.GetVirtualKey("sk-bf-test") - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) require.Error(t, err, "Should fail when hourly budget is exceeded even though daily is fine") assert.Contains(t, err.Error(), "budget exceeded") } @@ -283,8 +283,8 @@ func TestGovernanceStore_MultiBudget_LargeBudgetExceeded(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = store.GetVirtualKey("sk-bf-test") - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) require.Error(t, err, "Should fail when daily budget is exceeded even though hourly is fine") assert.Contains(t, err.Error(), "budget exceeded") } @@ -308,7 +308,7 @@ func TestGovernanceStore_MultiBudget_UsageUpdatesAllBudgets(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") // Simulate a $3.50 request err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 3.50) @@ -334,7 +334,7 @@ func 
TestGovernanceStore_MultiBudget_UsageUpdatesAllBudgets(t *testing.T) { assert.InDelta(t, 10.50, dailyVal.(*configstoreTables.TableBudget).CurrentUsage, 0.01, "Daily budget should accumulate") // Now CheckBudget should fail (hourly exceeded) - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) require.Error(t, err, "Should fail after usage exceeds hourly budget") assert.Contains(t, err.Error(), "budget exceeded") } @@ -359,8 +359,8 @@ func TestGovernanceStore_MultiBudget_ProviderConfigBudgets(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = store.GetVirtualKey("sk-bf-test") - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) require.Error(t, err, "Should fail when provider config hourly budget is exceeded") assert.Contains(t, err.Error(), "budget exceeded") } @@ -388,10 +388,10 @@ func TestGovernanceStore_MultiBudget_VKAndProviderConfigCombined(t *testing.T) { }, nil) require.NoError(t, err) - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") // Provider config budget exceeded → should block even though VK budget is fine - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) require.Error(t, err, "Should fail: provider config budget exceeded even though VK budget is fine") assert.Contains(t, err.Error(), "budget exceeded") } @@ -474,7 +474,7 @@ func TestGovernanceStore_MultiBudget_UsageDrivesBlockAfterRequests(t *testing.T) resolver := 
NewBudgetResolver(store, nil, logger, nil) // Request 1: $0.80 — both budgets fine - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 0.80) require.NoError(t, err) @@ -483,7 +483,7 @@ func TestGovernanceStore_MultiBudget_UsageDrivesBlockAfterRequests(t *testing.T) assertDecision(t, DecisionAllow, result) // Request 2: $0.80 — still fine ($1.60 total) - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 0.80) require.NoError(t, err) @@ -492,7 +492,7 @@ func TestGovernanceStore_MultiBudget_UsageDrivesBlockAfterRequests(t *testing.T) assertDecision(t, DecisionAllow, result) // Request 3: $0.80 — pushes hourly to $2.40 > $2.00 limit → blocked - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 0.80) require.NoError(t, err) @@ -542,11 +542,11 @@ func TestGovernanceStore_MultiBudget_CalendarAligned(t *testing.T) { require.NoError(t, err) // Verify VK-level calendar_aligned is set - vk, _ = store.GetVirtualKey("sk-bf-test") + vk, _ = store.GetVirtualKey(context.Background(), "sk-bf-test") assert.True(t, vk.CalendarAligned, "VK should have calendar_aligned=true") // Both under limit — should pass - err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + _, err = store.CheckVirtualKeyBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) assert.NoError(t, err) } @@ -568,26 +568,26 @@ func TestGovernanceStore_MultiBudget_InMemoryCreateAndDelete(t *testing.T) { } // Create - store.CreateVirtualKeyInMemory(vk) + store.CreateVirtualKeyInMemory(context.Background(), 
vk) _, exists := store.budgets.Load("b1") assert.True(t, exists, "Budget b1 should be in memory after create") _, exists = store.budgets.Load("b2") assert.True(t, exists, "Budget b2 should be in memory after create") - retrieved, found := store.GetVirtualKey("sk-bf-test") + retrieved, found := store.GetVirtualKey(context.Background(), "sk-bf-test") require.True(t, found) assert.Len(t, retrieved.Budgets, 2, "VK should have 2 budgets") // Delete - store.DeleteVirtualKeyInMemory("vk1") + store.DeleteVirtualKeyInMemory(context.Background(), "vk1") _, exists = store.budgets.Load("b1") assert.False(t, exists, "Budget b1 should be removed after delete") _, exists = store.budgets.Load("b2") assert.False(t, exists, "Budget b2 should be removed after delete") - _, found = store.GetVirtualKey("sk-bf-test") + _, found = store.GetVirtualKey(context.Background(), "sk-bf-test") assert.False(t, found, "VK should not be found after delete") } @@ -609,7 +609,7 @@ func TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests(t *testing.T) { assert.NoError(t, err, "Rate limit update should succeed") // Retrieve the updated rate limit from the main RateLimits map - governanceData := store.GetGovernanceData() + governanceData := store.GetGovernanceData(context.Background()) updatedRateLimit, exists := governanceData.RateLimits["rl1"] require.True(t, exists, "Rate limit should exist") require.NotNil(t, updatedRateLimit) @@ -622,7 +622,7 @@ func TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests(t *testing.T) { assert.NoError(t, err, "Rate limit update should succeed") // Retrieve the updated rate limit again - governanceData = store.GetGovernanceData() + governanceData = store.GetGovernanceData(context.Background()) updatedRateLimit, exists = governanceData.RateLimits["rl1"] require.True(t, exists, "Rate limit should exist") require.NotNil(t, updatedRateLimit) @@ -663,7 +663,7 @@ func TestGovernanceStore_ResetExpiredRateLimits(t *testing.T) { assert.NoError(t, err, "Reset 
should succeed") // Retrieve the updated VK to check rate limit changes - updatedVK, _ := store.GetVirtualKey("sk-bf-test") + updatedVK, _ := store.GetVirtualKey(context.Background(), "sk-bf-test") require.NotNil(t, updatedVK) require.NotNil(t, updatedVK.RateLimit) @@ -698,7 +698,7 @@ func TestGovernanceStore_ResetExpiredBudgets(t *testing.T) { assert.NoError(t, err, "Reset should succeed") // Retrieve the updated VK to check budget changes - updatedVK, _ := store.GetVirtualKey("sk-bf-test") + updatedVK, _ := store.GetVirtualKey(context.Background(), "sk-bf-test") require.NotNil(t, updatedVK) require.True(t, len(updatedVK.Budgets) > 0, "VK should have budgets") @@ -720,7 +720,7 @@ func TestGovernanceStore_GetAllBudgets(t *testing.T) { }, nil) require.NoError(t, err) - allBudgets := store.GetGovernanceData().Budgets + allBudgets := store.GetGovernanceData(context.Background()).Budgets assert.Equal(t, 3, len(allBudgets), "Should have 3 budgets") assert.NotNil(t, allBudgets["budget1"]) assert.NotNil(t, allBudgets["budget2"]) @@ -773,22 +773,22 @@ func TestGovernanceStore_RoutingRules_CreateAndRetrieve(t *testing.T) { } // Store rules in memory - err = store.UpdateRoutingRuleInMemory(rule1) + err = store.UpdateRoutingRuleInMemory(context.Background(), rule1) require.NoError(t, err) - err = store.UpdateRoutingRuleInMemory(rule2) + err = store.UpdateRoutingRuleInMemory(context.Background(), rule2) require.NoError(t, err) // Test retrieval by scope - globalRules := store.GetScopedRoutingRules("global", "") + globalRules := store.GetScopedRoutingRules(context.Background(), "global", "") assert.Equal(t, 1, len(globalRules)) assert.Equal(t, "Global Rule", globalRules[0].Name) - teamRules := store.GetScopedRoutingRules("team", teamID) + teamRules := store.GetScopedRoutingRules(context.Background(), "team", teamID) assert.Equal(t, 1, len(teamRules)) assert.Equal(t, "Team Rule", teamRules[0].Name) // Test ListRoutingRules - allRules := store.GetAllRoutingRules() + allRules := 
store.GetAllRoutingRules(context.Background()) assert.Equal(t, 2, len(allRules)) } @@ -827,12 +827,12 @@ func TestGovernanceStore_RoutingRules_PriorityOrdering(t *testing.T) { } for _, rule := range rules { - err := store.UpdateRoutingRuleInMemory(rule) + err := store.UpdateRoutingRuleInMemory(context.Background(), rule) require.NoError(t, err) } // Retrieve and verify ordering (sorted by priority ASC, so lower numbers first) - retrieved := store.GetScopedRoutingRules("global", "") + retrieved := store.GetScopedRoutingRules(context.Background(), "global", "") assert.Equal(t, 3, len(retrieved)) assert.Equal(t, 5, retrieved[0].Priority) assert.Equal(t, 10, retrieved[1].Priority) @@ -861,13 +861,13 @@ func TestGovernanceStore_RoutingRules_DisabledRulesFiltered(t *testing.T) { ScopeID: nil, } - err = store.UpdateRoutingRuleInMemory(enabledRule) + err = store.UpdateRoutingRuleInMemory(context.Background(), enabledRule) require.NoError(t, err) - err = store.UpdateRoutingRuleInMemory(disabledRule) + err = store.UpdateRoutingRuleInMemory(context.Background(), disabledRule) require.NoError(t, err) // Only enabled rules should be returned - retrieved := store.GetScopedRoutingRules("global", "") + retrieved := store.GetScopedRoutingRules(context.Background(), "global", "") assert.Equal(t, 1, len(retrieved)) assert.Equal(t, "Enabled Rule", retrieved[0].Name) } @@ -887,18 +887,18 @@ func TestGovernanceStore_RoutingRules_DeleteRule(t *testing.T) { } // Add rule - err = store.UpdateRoutingRuleInMemory(rule) + err = store.UpdateRoutingRuleInMemory(context.Background(), rule) require.NoError(t, err) - retrieved := store.GetScopedRoutingRules("global", "") + retrieved := store.GetScopedRoutingRules(context.Background(), "global", "") assert.Equal(t, 1, len(retrieved)) // Delete rule - err = store.DeleteRoutingRuleInMemory(rule.ID) + err = store.DeleteRoutingRuleInMemory(context.Background(), rule.ID) require.NoError(t, err) // Verify deletion - retrieved = 
store.GetScopedRoutingRules("global", "") + retrieved = store.GetScopedRoutingRules(context.Background(), "global", "") assert.Equal(t, 0, len(retrieved)) } @@ -994,27 +994,27 @@ func TestGovernanceStore_RoutingRules_MultipleScopes(t *testing.T) { ID: "3", Name: "Team", Scope: "team", ScopeID: &teamID, Priority: 30, Enabled: true, } - require.NoError(t, store.UpdateRoutingRuleInMemory(globalRule)) - require.NoError(t, store.UpdateRoutingRuleInMemory(customerRule)) - require.NoError(t, store.UpdateRoutingRuleInMemory(teamRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), globalRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), customerRule)) + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), teamRule)) // Test global scope - globalRules := store.GetScopedRoutingRules("global", "") + globalRules := store.GetScopedRoutingRules(context.Background(), "global", "") assert.Equal(t, 1, len(globalRules)) assert.Equal(t, "Global", globalRules[0].Name) // Test customer scope - custRules := store.GetScopedRoutingRules("customer", customerID) + custRules := store.GetScopedRoutingRules(context.Background(), "customer", customerID) assert.Equal(t, 1, len(custRules)) assert.Equal(t, "Customer", custRules[0].Name) // Test team scope - teamRules := store.GetScopedRoutingRules("team", teamID) + teamRules := store.GetScopedRoutingRules(context.Background(), "team", teamID) assert.Equal(t, 1, len(teamRules)) assert.Equal(t, "Team", teamRules[0].Name) // ListAll should return all rules sorted by priority ASC (lower numbers = higher priority) - allRules := store.GetAllRoutingRules() + allRules := store.GetAllRoutingRules(context.Background()) assert.Equal(t, 3, len(allRules)) assert.Equal(t, 10, allRules[0].Priority) // Global (highest) assert.Equal(t, 20, allRules[1].Priority) // Customer @@ -1038,12 +1038,12 @@ func TestCompileAndCacheProgram(t *testing.T) { } // First compilation - program1, err := 
store.GetRoutingProgram(rule) + program1, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program1) // Verify it's cached - second call should return cached program - program2, err := store.GetRoutingProgram(rule) + program2, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program2) @@ -1067,11 +1067,11 @@ func TestCompileAndCacheProgram_InvalidExpression(t *testing.T) { Enabled: true, } - _, err = store.GetRoutingProgram(rule) + _, err = store.GetRoutingProgram(context.Background(), rule) assert.Error(t, err) // Invalid rule should not be cached - attempting to get it again should fail - _, err = store.GetRoutingProgram(rule) + _, err = store.GetRoutingProgram(context.Background(), rule) assert.Error(t, err) } @@ -1093,17 +1093,17 @@ func TestCompileAndCacheProgram_CacheInvalidation(t *testing.T) { } // Compile and cache - program1, err := store.GetRoutingProgram(rule) + program1, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program1) // Update rule in memory (should invalidate cache) rule.CelExpression = "model == 'gpt-4-turbo'" - err = store.UpdateRoutingRuleInMemory(rule) + err = store.UpdateRoutingRuleInMemory(context.Background(), rule) require.NoError(t, err) // Recompile should work - program2, err := store.GetRoutingProgram(rule) + program2, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program2) } @@ -1126,11 +1126,11 @@ func TestCompileAndCacheProgram_CacheInvalidationOnDelete(t *testing.T) { } // Compile and cache - _, err = store.GetRoutingProgram(rule) + _, err = store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) // Delete rule (should invalidate cache) - err = store.DeleteRoutingRuleInMemory(rule.ID) + err = store.DeleteRoutingRuleInMemory(context.Background(), rule.ID) require.NoError(t, err) // After deletion, we 
can't verify cache directly, but the rule is gone from storage @@ -1152,12 +1152,12 @@ func TestCompileAndCacheProgram_EmptyExpression(t *testing.T) { Enabled: true, } - program, err := store.GetRoutingProgram(rule) + program, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program) // Verify caching works - second call should return same program - program2, err := store.GetRoutingProgram(rule) + program2, err := store.GetRoutingProgram(context.Background(), rule) require.NoError(t, err) assert.NotNil(t, program2) assert.Equal(t, program, program2) diff --git a/plugins/governance/tracker.go b/plugins/governance/tracker.go index f2460bad6d..8824910174 100644 --- a/plugins/governance/tracker.go +++ b/plugins/governance/tracker.go @@ -118,7 +118,7 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { } // Get virtual key - vk, exists := t.store.GetVirtualKey(update.VirtualKey) + vk, exists := t.store.GetVirtualKey(ctx, update.VirtualKey) if !exists { t.logger.Debug(fmt.Sprintf("Virtual key not found: %s", update.VirtualKey)) return diff --git a/plugins/governance/tracker_test.go b/plugins/governance/tracker_test.go index 6af947d0fa..dc5cd71940 100644 --- a/plugins/governance/tracker_test.go +++ b/plugins/governance/tracker_test.go @@ -45,7 +45,7 @@ func TestUsageTracker_UpdateUsage_FailedRequest(t *testing.T) { time.Sleep(200 * time.Millisecond) // Verify budget was NOT updated - retrieve from store - budgets := store.GetGovernanceData().Budgets + budgets := store.GetGovernanceData(context.Background()).Budgets updatedBudget, exists := budgets["budget1"] require.True(t, exists) require.NotNil(t, updatedBudget) @@ -116,7 +116,7 @@ func TestUsageTracker_UpdateUsage_StreamingOptimization(t *testing.T) { time.Sleep(200 * time.Millisecond) // Retrieve the updated rate limit from the main RateLimits map - governanceData := store.GetGovernanceData() + governanceData := 
store.GetGovernanceData(context.Background()) updatedRateLimit, exists := governanceData.RateLimits["rl1"] require.True(t, exists, "Rate limit should exist") require.NotNil(t, updatedRateLimit) @@ -142,7 +142,7 @@ func TestUsageTracker_UpdateUsage_StreamingOptimization(t *testing.T) { time.Sleep(200 * time.Millisecond) // Retrieve the updated rate limit again - governanceData = store.GetGovernanceData() + governanceData = store.GetGovernanceData(context.Background()) updatedRateLimit, exists = governanceData.RateLimits["rl1"] require.True(t, exists, "Rate limit should exist") require.NotNil(t, updatedRateLimit) diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go index 691d9ba6ca..260cb62e68 100644 --- a/plugins/governance/utils.go +++ b/plugins/governance/utils.go @@ -2,6 +2,7 @@ package governance import ( + "context" "strings" bifrost "github.com/maximhq/bifrost/core" @@ -89,11 +90,12 @@ func getWeight(w *float64) float64 { // filterModelsForVirtualKey filters models based on virtual key's provider configs // Returns only models that are allowed by the virtual key's ProviderConfigs func (p *GovernancePlugin) filterModelsForVirtualKey( + ctx context.Context, models []schemas.Model, virtualKeyValue string, ) []schemas.Model { // Get virtual key configuration - vk, exists := p.store.GetVirtualKey(virtualKeyValue) + vk, exists := p.store.GetVirtualKey(ctx, virtualKeyValue) if !exists { p.logger.Warn("[Governance] Virtual key not found for list models filtering: %s", virtualKeyValue) return []schemas.Model{} // VK not found, return empty list diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index 9fb5449b60..c5ad8de9ad 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -28,7 +28,7 @@ import ( // GovernanceManager is the interface for the governance manager type GovernanceManager interface { - GetGovernanceData() 
*governance.GovernanceData + GetGovernanceData(ctx context.Context) *governance.GovernanceData ReloadVirtualKey(ctx context.Context, id string) (*configstoreTables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error ReloadTeam(ctx context.Context, id string) (*configstoreTables.TableTeam, error) @@ -334,7 +334,7 @@ func (h *GovernanceHandler) getVirtualKeys(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -682,7 +682,7 @@ func (h *GovernanceHandler) getVirtualKey(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -1318,7 +1318,7 @@ func (h *GovernanceHandler) getTeams(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -1498,7 +1498,7 @@ func (h *GovernanceHandler) getTeam(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -1740,7 +1740,7 @@ 
func (h *GovernanceHandler) getCustomers(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -1897,7 +1897,7 @@ func (h *GovernanceHandler) getCustomer(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -2132,7 +2132,7 @@ func (h *GovernanceHandler) getBudgets(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -2160,7 +2160,7 @@ func (h *GovernanceHandler) getRateLimits(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -2248,7 +2248,7 @@ func validateBudget(budget *configstoreTables.TableBudget) error { func (h *GovernanceHandler) getModelConfigs(ctx *fasthttp.RequestCtx) { fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil 
{ SendError(ctx, 500, "Governance data is not available") return @@ -2672,7 +2672,7 @@ type ProviderGovernanceResponse struct { func (h *GovernanceHandler) getProviderGovernance(ctx *fasthttp.RequestCtx) { fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - data := h.governanceManager.GetGovernanceData() + data := h.governanceManager.GetGovernanceData(ctx) if data == nil { SendError(ctx, 500, "Governance data is not available") return @@ -2967,7 +2967,7 @@ func (h *GovernanceHandler) getRoutingRules(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - gd := h.governanceManager.GetGovernanceData() + gd := h.governanceManager.GetGovernanceData(ctx) if gd == nil { SendError(ctx, 500, "Governance data is not available") return @@ -3101,7 +3101,7 @@ func (h *GovernanceHandler) getRoutingRule(ctx *fasthttp.RequestCtx) { // Check if "from_memory" query parameter is set to true fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" if fromMemory { - gd := h.governanceManager.GetGovernanceData() + gd := h.governanceManager.GetGovernanceData(ctx) if gd == nil { SendError(ctx, 500, "Governance data is not available") return diff --git a/transports/bifrost-http/handlers/governance_test.go b/transports/bifrost-http/handlers/governance_test.go index 581a22e7b1..7a11fa9770 100644 --- a/transports/bifrost-http/handlers/governance_test.go +++ b/transports/bifrost-http/handlers/governance_test.go @@ -18,7 +18,7 @@ type mockGovernanceManagerForVK struct { GovernanceManager } -func (m *mockGovernanceManagerForVK) GetGovernanceData() *governance.GovernanceData { +func (m *mockGovernanceManagerForVK) GetGovernanceData(ctx context.Context) *governance.GovernanceData { return nil } diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 3b7e60a34e..64ce383eb9 100644 --- 
a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -124,7 +124,7 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { // Build VK id→name lookup from in-memory governance data vkNameByID := make(map[string]string) if h.governanceManager != nil { - if gd := h.governanceManager.GetGovernanceData(); gd != nil { + if gd := h.governanceManager.GetGovernanceData(ctx); gd != nil { for _, vk := range gd.VirtualKeys { vkNameByID[vk.ID] = vk.Name } @@ -250,7 +250,7 @@ func (h *MCPHandler) getMCPClientsPaginated(ctx *fasthttp.RequestCtx, limitStr, // Build VK id→name lookup from in-memory governance data (no extra DB queries) vkNameByID := make(map[string]string) if h.governanceManager != nil { - if gd := h.governanceManager.GetGovernanceData(); gd != nil { + if gd := h.governanceManager.GetGovernanceData(ctx); gd != nil { for _, vk := range gd.VirtualKeys { vkNameByID[vk.ID] = vk.Name } diff --git a/transports/bifrost-http/handlers/pricing_override_test.go b/transports/bifrost-http/handlers/pricing_override_test.go index 4d19d0541e..adc2c63d3b 100644 --- a/transports/bifrost-http/handlers/pricing_override_test.go +++ b/transports/bifrost-http/handlers/pricing_override_test.go @@ -20,7 +20,7 @@ import ( type pricingOverrideTestGovernanceManager struct{} -func (pricingOverrideTestGovernanceManager) GetGovernanceData() *governance.GovernanceData { +func (pricingOverrideTestGovernanceManager) GetGovernanceData(ctx context.Context) *governance.GovernanceData { return nil } func (pricingOverrideTestGovernanceManager) ReloadVirtualKey(context.Context, string) (*configstoreTables.TableVirtualKey, error) { diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 0bc470db23..3835c9f192 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -372,7 +372,6 @@ import ( "github.com/maximhq/bifrost/framework/configstore/tables" 
"github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/logstore" - "github.com/maximhq/bifrost/framework/migrator" "github.com/maximhq/bifrost/framework/modelcatalog" "github.com/maximhq/bifrost/framework/vectorstore" "github.com/stretchr/testify/require" @@ -418,6 +417,9 @@ func NewMockConfigStore() *MockConfigStore { } // Implement ConfigStore interface methods +func (m *MockConfigStore) RefreshConnectionPool(ctx context.Context) error { + return nil +} func (m *MockConfigStore) Ping(ctx context.Context) error { return nil } func (m *MockConfigStore) EncryptPlaintextRows(ctx context.Context) error { return nil } func (m *MockConfigStore) Close(ctx context.Context) error { return nil } @@ -426,7 +428,7 @@ func (m *MockConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *go return fn(nil) } -func (m *MockConfigStore) RunMigration(ctx context.Context, migration *migrator.Migration) error { +func (m *MockConfigStore) RunMigration(context.Context, func(context.Context, *gorm.DB) error) error { return nil } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 5c7a3ad679..53a13a8b6c 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -70,7 +70,7 @@ type ServerCallbacks interface { ReloadHeaderFilterConfig(ctx context.Context, config *tables.GlobalHeaderFilterConfig) error UpdateDropExcessRequests(ctx context.Context, value bool) // Governance related callbacks - GetGovernanceData() *governance.GovernanceData + GetGovernanceData(ctx context.Context) *governance.GovernanceData ReloadTeam(ctx context.Context, id string) (*tables.TableTeam, error) RemoveTeam(ctx context.Context, id string) error ReloadCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) @@ -329,7 +329,7 @@ func (s *BifrostHTTPServer) ReloadVirtualKey(ctx context.Context, id string) (*t if err != nil { return nil, err } - 
governancePlugin.GetGovernanceStore().UpdateVirtualKeyInMemory(virtualKey, nil, nil, nil) + governancePlugin.GetGovernanceStore().UpdateVirtualKeyInMemory(ctx, virtualKey, nil, nil, nil) s.MCPServerHandler.SyncVKMCPServer(virtualKey) return virtualKey, nil } @@ -348,10 +348,10 @@ func (s *BifrostHTTPServer) RemoveVirtualKey(ctx context.Context, id string) err } if preloadedVk == nil { // This could be broadcast message from other server, so we will just clean up in-memory store - governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(id) + governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(ctx, id) return nil } - governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(id) + governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(ctx, id) s.MCPServerHandler.DeleteVKMCPServer(preloadedVk.Value) return nil } @@ -369,7 +369,7 @@ func (s *BifrostHTTPServer) ReloadTeam(ctx context.Context, id string) (*tables. return nil, err } // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateTeamInMemory(preloadedTeam, nil) + governancePlugin.GetGovernanceStore().UpdateTeamInMemory(ctx, preloadedTeam, nil) return preloadedTeam, nil } @@ -387,10 +387,10 @@ func (s *BifrostHTTPServer) RemoveTeam(ctx context.Context, id string) error { } if preloadedTeam == nil { // At-least deleting from in-memory store to avoid conflicts - governancePlugin.GetGovernanceStore().DeleteTeamInMemory(id) + governancePlugin.GetGovernanceStore().DeleteTeamInMemory(ctx, id) return nil } - governancePlugin.GetGovernanceStore().DeleteTeamInMemory(id) + governancePlugin.GetGovernanceStore().DeleteTeamInMemory(ctx, id) return nil } @@ -405,7 +405,7 @@ func (s *BifrostHTTPServer) ReloadCustomer(ctx context.Context, id string) (*tab return nil, err } // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateCustomerInMemory(preloadedCustomer, nil) + governancePlugin.GetGovernanceStore().UpdateCustomerInMemory(ctx, preloadedCustomer, nil) return 
preloadedCustomer, nil } @@ -423,10 +423,10 @@ func (s *BifrostHTTPServer) RemoveCustomer(ctx context.Context, id string) error } if preloadedCustomer == nil { // At-least deleting from in-memory store to avoid conflicts - governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(id) + governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(ctx, id) return nil } - governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(id) + governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(ctx, id) return nil } @@ -443,7 +443,7 @@ func (s *BifrostHTTPServer) ReloadModelConfig(ctx context.Context, id string) (* return nil, err } // Update in memory and get back the potentially modified model config - updatedMC := governancePlugin.GetGovernanceStore().UpdateModelConfigInMemory(preloadedMC) + updatedMC := governancePlugin.GetGovernanceStore().UpdateModelConfigInMemory(ctx, preloadedMC) if updatedMC == nil { return preloadedMC, nil } @@ -475,7 +475,7 @@ func (s *BifrostHTTPServer) RemoveModelConfig(ctx context.Context, id string) er if err != nil { return err } - governancePlugin.GetGovernanceStore().DeleteModelConfigInMemory(id) + governancePlugin.GetGovernanceStore().DeleteModelConfigInMemory(ctx, id) return nil } @@ -507,7 +507,7 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas logger.Warn("governance plugin found but failed to get: %v", err) } else { // Update in memory and get back the potentially modified provider - govUpdated := governancePlugin.GetGovernanceStore().UpdateProviderInMemory(providerInfo) + govUpdated := governancePlugin.GetGovernanceStore().UpdateProviderInMemory(ctx, providerInfo) if govUpdated != nil { updatedProvider = govUpdated } @@ -630,7 +630,7 @@ func (s *BifrostHTTPServer) RemoveProvider(ctx context.Context, provider schemas if err != nil { return err } - governancePlugin.GetGovernanceStore().DeleteProviderInMemory(string(provider)) + 
governancePlugin.GetGovernanceStore().DeleteProviderInMemory(ctx, string(provider)) if s.Config == nil || s.Config.ModelCatalog == nil { return fmt.Errorf("pricing manager not found") } @@ -640,14 +640,13 @@ func (s *BifrostHTTPServer) RemoveProvider(ctx context.Context, provider schemas } // GetGovernanceData returns the governance data -func (s *BifrostHTTPServer) GetGovernanceData() *governance.GovernanceData { +func (s *BifrostHTTPServer) GetGovernanceData(ctx context.Context) *governance.GovernanceData { // Use type-safe finder from Config governancePlugin, err := lib.FindPluginAs[governance.BaseGovernancePlugin](s.Config, s.getGovernancePluginName()) if err != nil { return nil } - - return governancePlugin.GetGovernanceStore().GetGovernanceData() + return governancePlugin.GetGovernanceStore().GetGovernanceData(ctx) } // ReloadRoutingRule reloads a routing rule from the database into the governance store @@ -667,7 +666,7 @@ func (s *BifrostHTTPServer) ReloadRoutingRule(ctx context.Context, id string) er return fmt.Errorf("failed to get routing rule from config store: %w", err) } // Update the rule in the store (this updates the in-memory cache) - if err := store.UpdateRoutingRuleInMemory(rule); err != nil { + if err := store.UpdateRoutingRuleInMemory(ctx, rule); err != nil { return fmt.Errorf("failed to update routing rule in store: %w", err) } return nil @@ -686,7 +685,7 @@ func (s *BifrostHTTPServer) RemoveRoutingRule(ctx context.Context, id string) er // Get the governance store from the plugin store := governancePlugin.GetGovernanceStore() // Delete the rule from the store (this removes from in-memory cache) - if err := store.DeleteRoutingRuleInMemory(id); err != nil { + if err := store.DeleteRoutingRuleInMemory(ctx, id); err != nil { return fmt.Errorf("failed to delete routing rule from store: %w", err) } return nil diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx index 502bf88d9f..e246c8aa39 100644 --- 
a/ui/app/workspace/config/views/mcpView.tsx +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -1,258 +1,334 @@ import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { Switch } from "@/components/ui/switch"; -import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { + getErrorMessage, + useGetCoreConfigQuery, + useUpdateCoreConfigMutation, +} from "@/lib/store"; import { CoreConfig, DefaultCoreConfig } from "@/lib/types/config"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; import { useCallback, useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; export default function MCPView() { - const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); - const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); - const config = bifrostConfig?.client_config; - const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); - const [localConfig, setLocalConfig] = useState(DefaultCoreConfig); + const hasSettingsUpdateAccess = useRbac( + RbacResource.Settings, + RbacOperation.Update, + ); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [localConfig, setLocalConfig] = useState(DefaultCoreConfig); - const [localValues, setLocalValues] = useState<{ - mcp_agent_depth: string; - mcp_tool_execution_timeout: string; - mcp_code_mode_binding_level: string; - mcp_tool_sync_interval: string; - }>({ - mcp_agent_depth: "10", - mcp_tool_execution_timeout: "30", - mcp_code_mode_binding_level: "server", - 
mcp_tool_sync_interval: "10", - }); + const [localValues, setLocalValues] = useState<{ + mcp_agent_depth: string; + mcp_tool_execution_timeout: string; + mcp_code_mode_binding_level: string; + mcp_tool_sync_interval: string; + }>({ + mcp_agent_depth: "10", + mcp_tool_execution_timeout: "30", + mcp_code_mode_binding_level: "server", + mcp_tool_sync_interval: "10", + }); - useEffect(() => { - if (bifrostConfig && config) { - setLocalConfig(config); - setLocalValues({ - mcp_agent_depth: config?.mcp_agent_depth?.toString() || "10", - mcp_tool_execution_timeout: config?.mcp_tool_execution_timeout?.toString() || "30", - mcp_code_mode_binding_level: config?.mcp_code_mode_binding_level || "server", - mcp_tool_sync_interval: config?.mcp_tool_sync_interval?.toString() || "10", - }); - } - }, [config, bifrostConfig]); + useEffect(() => { + if (bifrostConfig && config) { + setLocalConfig(config); + setLocalValues({ + mcp_agent_depth: config?.mcp_agent_depth?.toString() || "10", + mcp_tool_execution_timeout: + config?.mcp_tool_execution_timeout?.toString() || "30", + mcp_code_mode_binding_level: + config?.mcp_code_mode_binding_level || "server", + mcp_tool_sync_interval: + config?.mcp_tool_sync_interval?.toString() || "10", + }); + } + }, [config, bifrostConfig]); - const hasChanges = useMemo(() => { - if (!config) return false; - return ( - localConfig.mcp_agent_depth !== config.mcp_agent_depth || - localConfig.mcp_tool_execution_timeout !== config.mcp_tool_execution_timeout || - localConfig.mcp_code_mode_binding_level !== (config.mcp_code_mode_binding_level || "server") || - localConfig.mcp_tool_sync_interval !== (config.mcp_tool_sync_interval ?? 10) || - localConfig.mcp_disable_auto_tool_inject !== (config.mcp_disable_auto_tool_inject ?? 
false) - ); - }, [config, localConfig]); + const hasChanges = useMemo(() => { + if (!config) return false; + return ( + localConfig.mcp_agent_depth !== config.mcp_agent_depth || + localConfig.mcp_tool_execution_timeout !== + config.mcp_tool_execution_timeout || + localConfig.mcp_code_mode_binding_level !== + (config.mcp_code_mode_binding_level || "server") || + localConfig.mcp_tool_sync_interval !== + (config.mcp_tool_sync_interval ?? 10) || + localConfig.mcp_disable_auto_tool_inject !== + (config.mcp_disable_auto_tool_inject ?? false) + ); + }, [config, localConfig]); - const handleAgentDepthChange = useCallback((value: string) => { - setLocalValues((prev) => ({ ...prev, mcp_agent_depth: value })); - const numValue = Number.parseInt(value); - if (!isNaN(numValue) && numValue > 0) { - setLocalConfig((prev) => ({ ...prev, mcp_agent_depth: numValue })); - } - }, []); + const handleAgentDepthChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_agent_depth: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_agent_depth: numValue })); + } + }, []); - const handleToolExecutionTimeoutChange = useCallback((value: string) => { - setLocalValues((prev) => ({ ...prev, mcp_tool_execution_timeout: value })); - const numValue = Number.parseInt(value); - if (!isNaN(numValue) && numValue > 0) { - setLocalConfig((prev) => ({ ...prev, mcp_tool_execution_timeout: numValue })); - } - }, []); + const handleToolExecutionTimeoutChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_tool_execution_timeout: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ + ...prev, + mcp_tool_execution_timeout: numValue, + })); + } + }, []); - const handleCodeModeBindingLevelChange = useCallback((value: string) => { - setLocalValues((prev) => ({ ...prev, mcp_code_mode_binding_level: 
value })); - if (value === "server" || value === "tool") { - setLocalConfig((prev) => ({ ...prev, mcp_code_mode_binding_level: value })); - } - }, []); + const handleCodeModeBindingLevelChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_code_mode_binding_level: value })); + if (value === "server" || value === "tool") { + setLocalConfig((prev) => ({ + ...prev, + mcp_code_mode_binding_level: value, + })); + } + }, []); - const handleToolSyncIntervalChange = useCallback((value: string) => { - setLocalValues((prev) => ({ ...prev, mcp_tool_sync_interval: value })); - const numValue = Number.parseInt(value); - if (!isNaN(numValue) && numValue >= 0) { - setLocalConfig((prev) => ({ ...prev, mcp_tool_sync_interval: numValue })); - } - }, []); + const handleToolSyncIntervalChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_tool_sync_interval: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue >= 0) { + setLocalConfig((prev) => ({ ...prev, mcp_tool_sync_interval: numValue })); + } + }, []); - const handleDisableAutoToolInjectChange = useCallback((checked: boolean) => { - setLocalConfig((prev) => ({ ...prev, mcp_disable_auto_tool_inject: checked })); - }, []); + const handleDisableAutoToolInjectChange = useCallback((checked: boolean) => { + setLocalConfig((prev) => ({ + ...prev, + mcp_disable_auto_tool_inject: checked, + })); + }, []); - const handleSave = useCallback(async () => { - try { - const agentDepth = Number.parseInt(localValues.mcp_agent_depth); - const toolTimeout = Number.parseInt(localValues.mcp_tool_execution_timeout); + const handleSave = useCallback(async () => { + try { + const agentDepth = Number.parseInt(localValues.mcp_agent_depth); + const toolTimeout = Number.parseInt( + localValues.mcp_tool_execution_timeout, + ); - if (isNaN(agentDepth) || agentDepth <= 0) { - toast.error("Max agent depth must be a positive number."); - return; - } + if 
(isNaN(agentDepth) || agentDepth <= 0) { + toast.error("Max agent depth must be a positive number."); + return; + } - if (isNaN(toolTimeout) || toolTimeout <= 0) { - toast.error("Tool execution timeout must be a positive number."); - return; - } + if (isNaN(toolTimeout) || toolTimeout <= 0) { + toast.error("Tool execution timeout must be a positive number."); + return; + } - if (!bifrostConfig) { - toast.error("Configuration not loaded. Please refresh and try again."); - return; - } - await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); - toast.success("MCP settings updated successfully."); - } catch (error) { - toast.error(getErrorMessage(error)); - } - }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ + ...bifrostConfig, + client_config: localConfig, + }).unwrap(); + toast.success("MCP settings updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); - return ( -
-
-

MCP Settings

-

Configure MCP (Model Context Protocol) agent and tool settings.

-
-
- {/* Max Agent Depth */} -
-
- -

Maximum depth for MCP agent execution.

-
- handleAgentDepthChange(e.target.value)} - min="1" - /> -
+ return ( +
+
+

MCP Settings

+

+ Configure MCP (Model Context Protocol) agent and tool settings. +

+
+
+ {/* Max Agent Depth */} +
+
+ +

+ Maximum depth for MCP agent execution. +

+
+ handleAgentDepthChange(e.target.value)} + min="1" + /> +
- {/* Tool Execution Timeout */} -
-
- -

Maximum time in seconds for tool execution.

-
- handleToolExecutionTimeoutChange(e.target.value)} - min="1" - /> -
+ {/* Tool Execution Timeout */} +
+
+ +

+ Maximum time in seconds for tool execution. +

+
+ handleToolExecutionTimeoutChange(e.target.value)} + min="1" + /> +
- {/* Tool Sync Interval */} -
-
- -

How often to refresh tool lists from MCP servers. Set to 0 to disable.

-
- handleToolSyncIntervalChange(e.target.value)} - min="0" - /> -
+ {/* Tool Sync Interval */} +
+
+ +

+ How often to refresh tool lists from MCP servers. Set to 0 to + disable. +

+
+ handleToolSyncIntervalChange(e.target.value)} + min="0" + /> +
- {/* Disable Auto Tool Injection */} -
-
- -

- When enabled, MCP tools are not automatically included in every request. Tools are only injected when explicitly specified via - request headers (x-bf-mcp-include-tools) and still must be allowed by the virtual key MCP - configuration. -

-
- -
+ {/* Disable Auto Tool Injection */} +
+
+ +

+ When enabled, MCP tools are not automatically included in every + request. Tools are only injected when explicitly specified via + request headers ( + x-bf-mcp-include-tools) and still + must be allowed by the virtual key MCP configuration. +

+
+ +
- {/* Code Mode Binding Level */} -
-
- -

- How tools are exposed in the VFS: server-level (all tools per server) or tool-level (individual tools). -

-
- + {/* Code Mode Binding Level */} +
+
+ +

+ How tools are exposed in the VFS: server-level (all tools per + server) or tool-level (individual tools). +

+
+ - {/* Visual Example */} -
-

VFS Structure:

+ {/* Visual Example */} +
+

+ VFS Structure: +

- {localValues.mcp_code_mode_binding_level === "server" ? ( -
-
-
servers/
-
├─ calculator.d.ts
-
├─ youtube.d.ts
-
└─ weather.d.ts
-
-

All tools per server in a single .d.ts file

-
- ) : ( -
-
-
servers/
-
├─ calculator/
-
├─ add.d.ts
-
└─ subtract.d.ts
-
├─ youtube/
-
├─ GET_CHANNELS.d.ts
-
└─ SEARCH_VIDEOS.d.ts
-
└─ weather/
-
└─ get_forecast.d.ts
-
-

Individual .d.ts file for each tool

-
- )} -
-
-
-
- -
-
- ); -} \ No newline at end of file + {localValues.mcp_code_mode_binding_level === "server" ? ( +
+
+
servers/
+
├─ calculator.py
+
├─ youtube.py
+
└─ weather.py
+
+

+ All tools per server in a single .py file +

+
+ ) : ( +
+
+
servers/
+
├─ calculator/
+
├─ add.py
+
└─ subtract.py
+
├─ youtube/
+
├─ GET_CHANNELS.py
+
└─ SEARCH_VIDEOS.py
+
└─ weather/
+
└─ get_forecast.py
+
+

+ Individual .py file for each tool +

+
+ )} +
+
+
+
+ +
+
+ ); +} diff --git a/ui/app/workspace/governance/views/teamDialog.tsx b/ui/app/workspace/governance/views/teamDialog.tsx index febf7c8034..6d282d1229 100644 --- a/ui/app/workspace/governance/views/teamDialog.tsx +++ b/ui/app/workspace/governance/views/teamDialog.tsx @@ -1,24 +1,48 @@ import FormFooter from "@/components/formFooter"; import { Badge } from "@/components/ui/badge"; import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, } from "@/components/ui/alertDialog"; -import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/components/ui/dialog"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import NumberAndSelect from "@/components/ui/numberAndSelect"; -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { Switch } from "@/components/ui/switch"; -import { resetDurationOptions, supportsCalendarAlignment } from "@/lib/constants/governance"; -import { getErrorMessage, useCreateTeamMutation, useUpdateTeamMutation } from "@/lib/store"; -import { CreateTeamRequest, Customer, Team, UpdateTeamRequest } from "@/lib/types/governance"; +import { + resetDurationOptions, + supportsCalendarAlignment, +} from "@/lib/constants/governance"; +import { + getErrorMessage, + useCreateTeamMutation, + useUpdateTeamMutation, +} from "@/lib/store"; +import { + CreateTeamRequest, + Customer, + Team, + UpdateTeamRequest, +} from 
"@/lib/types/governance"; import { formatCurrency } from "@/lib/utils/governance"; import { Validator } from "@/lib/utils/validation"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; @@ -28,458 +52,610 @@ import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; interface TeamDialogProps { - team?: Team | null; - customers: Customer[]; - onSave: () => void; - onCancel: () => void; + team?: Team | null; + customers: Customer[]; + onSave: () => void; + onCancel: () => void; } interface TeamFormData { - name: string; - customerId: string; - // Budget - budgetMaxLimit: number | undefined; - budgetResetDuration: string; - budgetCalendarAligned: boolean; - // Rate Limit - tokenMaxLimit: number | undefined; - tokenResetDuration: string; - requestMaxLimit: number | undefined; - requestResetDuration: string; - isDirty: boolean; + name: string; + customerId: string; + // Budget + budgetMaxLimit: number | undefined; + budgetResetDuration: string; + budgetCalendarAligned: boolean; + // Rate Limit + tokenMaxLimit: number | undefined; + tokenResetDuration: string; + requestMaxLimit: number | undefined; + requestResetDuration: string; + isDirty: boolean; } // Helper function to create initial state -const createInitialState = (team?: Team | null): Omit => { - return { - name: team?.name || "", - customerId: team?.customer_id || "", - // Budget - budgetMaxLimit: team?.budget?.max_limit ?? undefined, - budgetResetDuration: team?.budget?.reset_duration || "1M", - budgetCalendarAligned: team?.budget?.calendar_aligned ?? false, - // Rate Limit - tokenMaxLimit: team?.rate_limit?.token_max_limit ?? undefined, - tokenResetDuration: team?.rate_limit?.token_reset_duration || "1h", - requestMaxLimit: team?.rate_limit?.request_max_limit ?? 
undefined, - requestResetDuration: team?.rate_limit?.request_reset_duration || "1h", - }; +const createInitialState = ( + team?: Team | null, +): Omit => { + return { + name: team?.name || "", + customerId: team?.customer_id || "", + // Budget + budgetMaxLimit: team?.budget?.max_limit ?? undefined, + budgetResetDuration: team?.budget?.reset_duration || "1M", + budgetCalendarAligned: team?.budget?.calendar_aligned ?? false, + // Rate Limit + tokenMaxLimit: team?.rate_limit?.token_max_limit ?? undefined, + tokenResetDuration: team?.rate_limit?.token_reset_duration || "1h", + requestMaxLimit: team?.rate_limit?.request_max_limit ?? undefined, + requestResetDuration: team?.rate_limit?.request_reset_duration || "1h", + }; }; -export default function TeamDialog({ team, customers, onSave, onCancel }: TeamDialogProps) { - const isEditing = !!team; - const [initialState] = useState>(createInitialState(team)); - const [formData, setFormData] = useState({ - ...initialState, - isDirty: false, - }); - - const hasCreateAccess = useRbac(RbacResource.Teams, RbacOperation.Create); - const hasUpdateAccess = useRbac(RbacResource.Teams, RbacOperation.Update); - const hasPermission = isEditing ? 
hasUpdateAccess : hasCreateAccess; - - // RTK Query hooks - const [createTeam, { isLoading: isCreating }] = useCreateTeamMutation(); - const [updateTeam, { isLoading: isUpdating }] = useUpdateTeamMutation(); - const loading = isCreating || isUpdating; - - const [showCalendarAlignWarning, setShowCalendarAlignWarning] = useState(false); - - const handleCalendarAlignedChange = (checked: boolean) => { - if (checked && isEditing && team?.budget && !team.budget.calendar_aligned) { - setShowCalendarAlignWarning(true); - } else { - updateField("budgetCalendarAligned", checked); - } - }; - - // Track isDirty state - useEffect(() => { - const currentData = { - name: formData.name, - customerId: formData.customerId, - budgetMaxLimit: formData.budgetMaxLimit, - budgetResetDuration: formData.budgetResetDuration, - budgetCalendarAligned: formData.budgetCalendarAligned, - tokenMaxLimit: formData.tokenMaxLimit, - tokenResetDuration: formData.tokenResetDuration, - requestMaxLimit: formData.requestMaxLimit, - requestResetDuration: formData.requestResetDuration, - }; - setFormData((prev) => ({ - ...prev, - isDirty: !isEqual(initialState, currentData), - })); - }, [ - formData.name, - formData.customerId, - formData.budgetMaxLimit, - formData.budgetResetDuration, - formData.budgetCalendarAligned, - formData.tokenMaxLimit, - formData.tokenResetDuration, - formData.requestMaxLimit, - formData.requestResetDuration, - initialState, - ]); - - // Values for validation and submission (already numbers) - const budgetMaxLimitNum = formData.budgetMaxLimit; - const tokenMaxLimitNum = formData.tokenMaxLimit; - const requestMaxLimitNum = formData.requestMaxLimit; - - // Validation - const validator = useMemo( - () => - new Validator([ - // Basic validation - Validator.required(formData.name.trim(), "Team name is required"), - - // Check if anything is dirty - Validator.custom(formData.isDirty, "No changes to save"), - - // Budget validation - ...(formData.budgetMaxLimit - ? 
[ - Validator.minValue(budgetMaxLimitNum || 0, 0.01, "Budget max limit must be greater than $0.01"), - Validator.required(formData.budgetResetDuration, "Budget reset duration is required"), - ] - : []), - - // Rate limit validation - token limits - ...(formData.tokenMaxLimit - ? [ - Validator.minValue(tokenMaxLimitNum || 0, 1, "Token max limit must be at least 1"), - Validator.required(formData.tokenResetDuration, "Token reset duration is required"), - ] - : []), - - // Rate limit validation - request limits - ...(formData.requestMaxLimit - ? [ - Validator.minValue(requestMaxLimitNum || 0, 1, "Request max limit must be at least 1"), - Validator.required(formData.requestResetDuration, "Request reset duration is required"), - ] - : []), - ]), - [formData, budgetMaxLimitNum, tokenMaxLimitNum, requestMaxLimitNum], - ); - - const updateField = (field: K, value: TeamFormData[K]) => { - setFormData((prev) => ({ ...prev, [field]: value })); - }; - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - - if (!validator.isValid()) { - toast.error(validator.getFirstError()); - return; - } - - try { - if (isEditing && team) { - // Update existing team - const updateData: UpdateTeamRequest = { - name: formData.name, - customer_id: formData.customerId, - }; - - // Detect budget changes using had/has pattern - const hadBudget = !!team.budget; - const hasBudget = budgetMaxLimitNum !== undefined && budgetMaxLimitNum !== null; - if (hasBudget) { - updateData.budget = { - max_limit: budgetMaxLimitNum, - reset_duration: formData.budgetResetDuration, - calendar_aligned: formData.budgetCalendarAligned, - }; - } else if (hadBudget) { - updateData.budget = {} as UpdateTeamRequest["budget"]; - } - - // Detect rate limit changes using had/has pattern - const hadRateLimit = !!team.rate_limit; - const hasRateLimit = - (tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null) || - (requestMaxLimitNum !== undefined && requestMaxLimitNum !== null); - if (hasRateLimit) 
{ - updateData.rate_limit = { - token_max_limit: tokenMaxLimitNum, - token_reset_duration: tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null ? formData.tokenResetDuration : undefined, - request_max_limit: requestMaxLimitNum, - request_reset_duration: - requestMaxLimitNum !== undefined && requestMaxLimitNum !== null ? formData.requestResetDuration : undefined, - }; - } else if (hadRateLimit) { - updateData.rate_limit = {} as UpdateTeamRequest["rate_limit"]; - } - - await updateTeam({ teamId: team.id, data: updateData }).unwrap(); - toast.success("Team updated successfully"); - } else { - // Create new team - const createData: CreateTeamRequest = { - name: formData.name, - customer_id: formData.customerId || undefined, - }; - - // Add budget if enabled - if (budgetMaxLimitNum !== undefined && budgetMaxLimitNum !== null) { - createData.budget = { - max_limit: budgetMaxLimitNum, - reset_duration: formData.budgetResetDuration, - calendar_aligned: formData.budgetCalendarAligned, - }; - } - - // Add rate limit if enabled (token or request limits) - if ( - (tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null) || - (requestMaxLimitNum !== undefined && requestMaxLimitNum !== null) - ) { - createData.rate_limit = { - token_max_limit: tokenMaxLimitNum, - token_reset_duration: tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null ? formData.tokenResetDuration : undefined, - request_max_limit: requestMaxLimitNum, - request_reset_duration: - requestMaxLimitNum !== undefined && requestMaxLimitNum !== null ? formData.requestResetDuration : undefined, - }; - } - - await createTeam(createData).unwrap(); - toast.success("Team created successfully"); - } - - onSave(); - } catch (error) { - toast.error(getErrorMessage(error)); - } - }; - - return ( - - - - {isEditing ? "Edit Team" : "Create Team"} - - {isEditing ? "Update the team information and settings." : "Create a new team to organize users and manage shared resources."} - - - -
-
- {/* Basic Information */} -
-
- - updateField("name", e.target.value)} - data-testid="team-name-input" - /> -
- - {/* Customer Assignment */} - {customers?.length > 0 && ( -
- - -

Assign to a customer or leave independent.

-
- )} -
- - {/* Budget Configuration */} - updateField("budgetMaxLimit", value)} - onChangeSelect={(value) => { - updateField("budgetResetDuration", value); - if (!supportsCalendarAlignment(value)) { - updateField("budgetCalendarAligned", false); - } - }} - options={resetDurationOptions} - dataTestId="budget-max-limit-input" - /> - - {/* Calendar alignment toggle — only shown when a budget is set and the period supports alignment */} - {formData.budgetMaxLimit && supportsCalendarAlignment(formData.budgetResetDuration) && ( -
-
- -

- Reset at the start of each period (e.g. 1st of month) instead of rolling from creation date -

-
- -
- )} - - {/* Warning dialog shown when enabling calendar alignment on an existing budget */} - - - - Reset budget usage? - - Enabling calendar alignment will reset this budget's current usage to $0.00{" "} - and snap the reset date to the start of the current{" "} - {formData.budgetResetDuration === "1d" - ? "day" - : formData.budgetResetDuration === "1w" - ? "week" - : formData.budgetResetDuration === "1M" - ? "month" - : formData.budgetResetDuration === "1Y" - ? "year" - : "period"} - . The usage reset to $0.00 cannot be undone, but calendar alignment can be turned off later. This will take effect when - you save. - - - - Cancel - { - updateField("budgetCalendarAligned", true); - setShowCalendarAlignWarning(false); - }} - > - Enable Calendar Alignment - - - - - - {/* Rate Limit Configuration - Token Limits */} - updateField("tokenMaxLimit", value)} - onChangeSelect={(value) => updateField("tokenResetDuration", value)} - options={resetDurationOptions} - /> - - {/* Rate Limit Configuration - Request Limits */} - updateField("requestMaxLimit", value)} - onChangeSelect={(value) => updateField("requestResetDuration", value)} - options={resetDurationOptions} - /> - - {/* Current Usage Section (only shown when editing with existing limits) */} - {isEditing && (team?.budget || team?.rate_limit) && ( -
-

Current Usage

-
- {team?.budget && ( -
-

Budget

-
- - {formatCurrency(team.budget.current_usage)} / {formatCurrency(team.budget.max_limit)} - - = team.budget.max_limit ? "destructive" : "default"} className="text-xs"> - {Math.round((team.budget.current_usage / team.budget.max_limit) * 100)}% - -
-

- Last Reset: {formatDistanceToNow(new Date(team.budget.last_reset), { addSuffix: true })} -

-
- )} - {team?.rate_limit?.token_max_limit && ( -
-

Tokens

-
- - {team.rate_limit.token_current_usage.toLocaleString()} / {team.rate_limit.token_max_limit.toLocaleString()} - - = team.rate_limit.token_max_limit ? "destructive" : "default"} - className="text-xs" - > - {Math.round((team.rate_limit.token_current_usage / team.rate_limit.token_max_limit) * 100)}% - -
-

- Last Reset: {formatDistanceToNow(new Date(team.rate_limit.token_last_reset), { addSuffix: true })} -

-
- )} - {team?.rate_limit?.request_max_limit && ( -
-

Requests

-
- - {team.rate_limit.request_current_usage.toLocaleString()} / {team.rate_limit.request_max_limit.toLocaleString()} - - = team.rate_limit.request_max_limit ? "destructive" : "default"} - className="text-xs" - > - {Math.round((team.rate_limit.request_current_usage / team.rate_limit.request_max_limit) * 100)}% - -
-

- Last Reset: {formatDistanceToNow(new Date(team.rate_limit.request_last_reset), { addSuffix: true })} -

-
- )} -
-
- )} -
- - - -
-
- ); -} \ No newline at end of file +export default function TeamDialog({ + team, + customers, + onSave, + onCancel, +}: TeamDialogProps) { + const isEditing = !!team; + const [initialState, setInitialState] = useState< + Omit + >(createInitialState(team)); + const [formData, setFormData] = useState({ + ...initialState, + isDirty: false, + }); + + useEffect(() => { + const nextInitial = createInitialState(team); + setInitialState(nextInitial); + setFormData({ ...nextInitial, isDirty: false }); + setShowCalendarAlignWarning(false); + }, [team]); + + const hasCreateAccess = useRbac(RbacResource.Teams, RbacOperation.Create); + const hasUpdateAccess = useRbac(RbacResource.Teams, RbacOperation.Update); + const hasPermission = isEditing ? hasUpdateAccess : hasCreateAccess; + + // RTK Query hooks + const [createTeam, { isLoading: isCreating }] = useCreateTeamMutation(); + const [updateTeam, { isLoading: isUpdating }] = useUpdateTeamMutation(); + const loading = isCreating || isUpdating; + + const [showCalendarAlignWarning, setShowCalendarAlignWarning] = + useState(false); + + const handleCalendarAlignedChange = (checked: boolean) => { + if (checked && isEditing && team?.budget && !team.budget.calendar_aligned) { + setShowCalendarAlignWarning(true); + } else { + updateField("budgetCalendarAligned", checked); + } + }; + + // Track isDirty state + useEffect(() => { + const currentData = { + name: formData.name, + customerId: formData.customerId, + budgetMaxLimit: formData.budgetMaxLimit, + budgetResetDuration: formData.budgetResetDuration, + budgetCalendarAligned: formData.budgetCalendarAligned, + tokenMaxLimit: formData.tokenMaxLimit, + tokenResetDuration: formData.tokenResetDuration, + requestMaxLimit: formData.requestMaxLimit, + requestResetDuration: formData.requestResetDuration, + }; + setFormData((prev) => ({ + ...prev, + isDirty: !isEqual(initialState, currentData), + })); + }, [ + formData.name, + formData.customerId, + formData.budgetMaxLimit, + 
formData.budgetResetDuration, + formData.budgetCalendarAligned, + formData.tokenMaxLimit, + formData.tokenResetDuration, + formData.requestMaxLimit, + formData.requestResetDuration, + initialState, + ]); + + // Values for validation and submission (already numbers) + const budgetMaxLimitNum = formData.budgetMaxLimit; + const tokenMaxLimitNum = formData.tokenMaxLimit; + const requestMaxLimitNum = formData.requestMaxLimit; + + // Validation + const validator = useMemo( + () => + new Validator([ + // Basic validation + Validator.required(formData.name.trim(), "Team name is required"), + + // Check if anything is dirty + Validator.custom(formData.isDirty, "No changes to save"), + + // Budget validation + ...(formData.budgetMaxLimit !== undefined && + formData.budgetMaxLimit !== null + ? [ + Validator.minValue( + budgetMaxLimitNum || 0, + 0.01, + "Budget max limit must be greater than $0.01", + ), + Validator.required( + formData.budgetResetDuration, + "Budget reset duration is required", + ), + ] + : []), + + // Rate limit validation - token limits + ...(formData.tokenMaxLimit !== undefined && + formData.tokenMaxLimit !== null + ? [ + Validator.minValue( + tokenMaxLimitNum || 0, + 1, + "Token max limit must be at least 1", + ), + Validator.required( + formData.tokenResetDuration, + "Token reset duration is required", + ), + ] + : []), + + // Rate limit validation - request limits + ...(formData.requestMaxLimit !== undefined && + formData.requestMaxLimit !== null + ? 
[ + Validator.minValue( + requestMaxLimitNum || 0, + 1, + "Request max limit must be at least 1", + ), + Validator.required( + formData.requestResetDuration, + "Request reset duration is required", + ), + ] + : []), + ]), + [formData, budgetMaxLimitNum, tokenMaxLimitNum, requestMaxLimitNum], + ); + + const updateField = ( + field: K, + value: TeamFormData[K], + ) => { + setFormData((prev) => ({ ...prev, [field]: value })); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + + if (!validator.isValid()) { + toast.error(validator.getFirstError()); + return; + } + + try { + if (isEditing && team) { + // Update existing team + const updateData: UpdateTeamRequest = { + name: formData.name, + customer_id: formData.customerId || undefined, + }; + + // Detect budget changes using had/has pattern + const hadBudget = !!team.budget; + const hasBudget = + budgetMaxLimitNum !== undefined && budgetMaxLimitNum !== null; + if (hasBudget) { + updateData.budget = { + max_limit: budgetMaxLimitNum, + reset_duration: formData.budgetResetDuration, + calendar_aligned: formData.budgetCalendarAligned, + }; + } else if (hadBudget) { + updateData.budget = {} as UpdateTeamRequest["budget"]; + } + + // Detect rate limit changes using had/has pattern + const hadRateLimit = !!team.rate_limit; + const hasRateLimit = + (tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null) || + (requestMaxLimitNum !== undefined && requestMaxLimitNum !== null); + if (hasRateLimit) { + updateData.rate_limit = { + token_max_limit: tokenMaxLimitNum, + token_reset_duration: + tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null + ? formData.tokenResetDuration + : undefined, + request_max_limit: requestMaxLimitNum, + request_reset_duration: + requestMaxLimitNum !== undefined && requestMaxLimitNum !== null + ? 
formData.requestResetDuration + : undefined, + }; + } else if (hadRateLimit) { + updateData.rate_limit = {} as UpdateTeamRequest["rate_limit"]; + } + + await updateTeam({ teamId: team.id, data: updateData }).unwrap(); + toast.success("Team updated successfully"); + } else { + // Create new team + const createData: CreateTeamRequest = { + name: formData.name, + customer_id: formData.customerId || undefined, + }; + + // Add budget if enabled + if (budgetMaxLimitNum !== undefined && budgetMaxLimitNum !== null) { + createData.budget = { + max_limit: budgetMaxLimitNum, + reset_duration: formData.budgetResetDuration, + calendar_aligned: formData.budgetCalendarAligned, + }; + } + + // Add rate limit if enabled (token or request limits) + if ( + (tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null) || + (requestMaxLimitNum !== undefined && requestMaxLimitNum !== null) + ) { + createData.rate_limit = { + token_max_limit: tokenMaxLimitNum, + token_reset_duration: + tokenMaxLimitNum !== undefined && tokenMaxLimitNum !== null + ? formData.tokenResetDuration + : undefined, + request_max_limit: requestMaxLimitNum, + request_reset_duration: + requestMaxLimitNum !== undefined && requestMaxLimitNum !== null + ? formData.requestResetDuration + : undefined, + }; + } + + await createTeam(createData).unwrap(); + toast.success("Team created successfully"); + } + + onSave(); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }; + + return ( + + + + + {isEditing ? "Edit Team" : "Create Team"} + + + {isEditing + ? "Update the team information and settings." + : "Create a new team to organize users and manage shared resources."} + + + +
+
+ {/* Basic Information */} +
+
+ + updateField("name", e.target.value)} + data-testid="team-name-input" + /> +
+ + {/* Customer Assignment */} + {customers?.length > 0 && ( +
+ + +

+ Assign to a customer or leave independent. +

+
+ )} +
+ + {/* Budget Configuration */} + updateField("budgetMaxLimit", value)} + onChangeSelect={(value) => { + updateField("budgetResetDuration", value); + if (!supportsCalendarAlignment(value)) { + updateField("budgetCalendarAligned", false); + } + }} + options={resetDurationOptions} + dataTestId="budget-max-limit-input" + /> + + {/* Calendar alignment toggle — only shown when a budget is set and the period supports alignment */} + {formData.budgetMaxLimit && + supportsCalendarAlignment(formData.budgetResetDuration) && ( +
+
+ +

+ Reset at the start of each period (e.g. 1st of month) + instead of rolling from creation date +

+
+ +
+ )} + + {/* Warning dialog shown when enabling calendar alignment on an existing budget */} + + + + Reset budget usage? + + Enabling calendar alignment will reset this budget's + current usage to{" "} + $0.00 and snap the + reset date to the start of the current{" "} + {formData.budgetResetDuration === "1d" + ? "day" + : formData.budgetResetDuration === "1w" + ? "week" + : formData.budgetResetDuration === "1M" + ? "month" + : formData.budgetResetDuration === "1Y" + ? "year" + : "period"} + . The usage reset to $0.00 cannot be undone, but calendar + alignment can be turned off later. This will take effect + when you save. + + + + + Cancel + + { + updateField("budgetCalendarAligned", true); + setShowCalendarAlignWarning(false); + }} + > + Enable Calendar Alignment + + + + + + {/* Rate Limit Configuration - Token Limits */} + updateField("tokenMaxLimit", value)} + onChangeSelect={(value) => + updateField("tokenResetDuration", value) + } + options={resetDurationOptions} + /> + + {/* Rate Limit Configuration - Request Limits */} + updateField("requestMaxLimit", value)} + onChangeSelect={(value) => + updateField("requestResetDuration", value) + } + options={resetDurationOptions} + /> + + {/* Current Usage Section (only shown when editing with existing limits) */} + {isEditing && (team?.budget || team?.rate_limit) && ( +
+

Current Usage

+
+ {team?.budget && ( +
+

Budget

+
+ + {formatCurrency(team.budget.current_usage)} /{" "} + {formatCurrency(team.budget.max_limit)} + + 0 && + team.budget.current_usage >= team.budget.max_limit + ? "destructive" + : "default" + } + className="text-xs" + > + {team.budget.max_limit > 0 + ? Math.round( + (team.budget.current_usage / + team.budget.max_limit) * + 100, + ) + : 0} + % + +
+

+ Last Reset:{" "} + {formatDistanceToNow(new Date(team.budget.last_reset), { + addSuffix: true, + })} +

+
+ )} + {team?.rate_limit?.token_max_limit && ( +
+

Tokens

+
+ + {team.rate_limit.token_current_usage.toLocaleString()}{" "} + / {team.rate_limit.token_max_limit.toLocaleString()} + + 0 && + team.rate_limit.token_current_usage >= + team.rate_limit.token_max_limit + ? "destructive" + : "default" + } + className="text-xs" + > + {team.rate_limit.token_max_limit > 0 + ? Math.round( + (team.rate_limit.token_current_usage / + team.rate_limit.token_max_limit) * + 100, + ) + : 0} + % + +
+

+ Last Reset:{" "} + {formatDistanceToNow( + new Date(team.rate_limit.token_last_reset), + { addSuffix: true }, + )} +

+
+ )} + {team?.rate_limit?.request_max_limit && ( +
+

Requests

+
+ + {team.rate_limit.request_current_usage.toLocaleString()}{" "} + / {team.rate_limit.request_max_limit.toLocaleString()} + + 0 && + team.rate_limit.request_current_usage >= + team.rate_limit.request_max_limit + ? "destructive" + : "default" + } + className="text-xs" + > + {team.rate_limit.request_max_limit > 0 + ? Math.round( + (team.rate_limit.request_current_usage / + team.rate_limit.request_max_limit) * + 100, + ) + : 0} + % + +
+

+ Last Reset:{" "} + {formatDistanceToNow( + new Date(team.rate_limit.request_last_reset), + { addSuffix: true }, + )} +

+
+ )} +
+
+ )} +
+ + + +
+
+ ); +} diff --git a/ui/components/ui/badge.tsx b/ui/components/ui/badge.tsx index dbae1f192a..385d44d3a8 100644 --- a/ui/components/ui/badge.tsx +++ b/ui/components/ui/badge.tsx @@ -5,33 +5,44 @@ import * as React from "react"; import { cn } from "@/lib/utils"; const badgeVariants = cva( - "inline-flex items-center justify-center rounded-sm border px-2 py-0.5 text-xs font-medium w-fit whitespace-nowrap shrink-0 [&>svg]:size-3 gap-1 [&>svg]:pointer-events-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive transition-[color,box-shadow] overflow-hidden", - { - variants: { - variant: { - default: "border-transparent bg-primary text-primary-foreground [a&]:hover:bg-primary/90", - secondary: "border-transparent bg-secondary text-secondary-foreground [a&]:hover:bg-secondary/90", - destructive: - "border-transparent bg-destructive text-white [a&]:hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60", - outline: "text-foreground [a&]:hover:bg-accent [a&]:hover:text-accent-foreground", - success: "border-transparent bg-green-700 text-white [a&]:hover:bg-green-700/90", - }, - }, - defaultVariants: { - variant: "default", - }, - }, + "inline-flex items-center justify-center rounded-sm border px-2 py-0.5 text-xs font-medium w-fit whitespace-nowrap shrink-0 [&>svg]:size-3 gap-1 [&>svg]:pointer-events-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive transition-[color,box-shadow] overflow-hidden", + { + variants: { + variant: { + default: + "border-transparent bg-primary/10 border-primary/50 text-primary [a&]:hover:bg-primary/90 [a&]:hover:text-primary-foreground", + secondary: + "border-transparent bg-secondary text-secondary-foreground 
[a&]:hover:bg-secondary/90", + destructive: + "border-transparent bg-destructive/10 border-destructive/50 text-black dark:text-destructive-foreground [a&]:hover:bg-destructive/90 [a&]:hover:text-destructive-foreground focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60", + outline: + "text-foreground [a&]:hover:bg-accent [a&]:hover:text-accent-foreground", + success: + "border-transparent bg-green-100 border-green-500 text-black [a&]:hover:bg-green-700/90 [a&]:hover:text-white", + }, + }, + defaultVariants: { + variant: "default", + }, + }, ); function Badge({ - className, - variant, - asChild = false, - ...props -}: React.ComponentProps<"span"> & VariantProps & { asChild?: boolean }) { - const Comp = asChild ? Slot : "span"; + className, + variant, + asChild = false, + ...props +}: React.ComponentProps<"span"> & + VariantProps & { asChild?: boolean }) { + const Comp = asChild ? Slot : "span"; - return ; + return ( + + ); } -export { Badge, badgeVariants }; \ No newline at end of file +export { Badge, badgeVariants };