diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb2..2f94b553af 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1 @@ +- fix: preserve context values in async requests \ No newline at end of file diff --git a/framework/logstore/asyncjob.go b/framework/logstore/asyncjob.go index 38173c6840..0eb2e8c7de 100644 --- a/framework/logstore/asyncjob.go +++ b/framework/logstore/asyncjob.go @@ -9,6 +9,7 @@ import ( "github.com/bytedance/sonic" "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/valyala/fasthttp" @@ -80,11 +81,13 @@ func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValu } // SubmitJob creates a pending job, starts background execution, and returns the job record. -func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) { +func (e *AsyncJobExecutor) SubmitJob(bifrostCtx *schemas.BifrostContext, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) { if resultTTL <= 0 { resultTTL = DefaultAsyncJobResultTTL } + virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) + var virtualKeyID *string if virtualKeyValue != nil { vk, ok := e.governanceStore.GetVirtualKey(*virtualKeyValue) @@ -109,15 +112,24 @@ func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, ope return nil, fmt.Errorf("failed to create async job: %w", err) } - go e.executeJob(job.ID, job.ResultTTL, operation) + go e.executeJob(job.ID, job.ResultTTL, operation, bifrostCtx.GetUserValues()) return job, nil } // executeJob runs the operation in the background and updates the job record. -func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation) { +func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation, contextValues map[any]any) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + // Restore original request context values (virtual key, tracing headers, etc.) + for k, v := range contextValues { + ctx.SetValue(k, v) + } + + ctx.ClearValue(schemas.BifrostContextKeyTraceID) + ctx.ClearValue(schemas.BifrostContextKeyParentSpanID) + ctx.ClearValue(schemas.BifrostContextKeySpanID) + markFailed := func(msg string) { now := time.Now().UTC() expiresAt := now.Add(time.Duration(resultTTL) * time.Second) @@ -284,3 +296,13 @@ func (c *AsyncJobCleaner) cleanupExpiredJobs(ctx context.Context) { c.logger.Warn("async job cleanup: deleted %d stale processing jobs (stuck > %dh)", staleDeleted, asyncJobStaleProcessingHours) } } + +// getVirtualKeyFromContext extracts the virtual key value from context. +// Returns nil if no VK is present (e.g., direct key mode or no governance). +func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string { + vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) + if vkValue == "" { + return nil + } + return &vkValue +} diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index d85504cfcf..b50a28172a 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -118,11 +118,10 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.TextCompletionRequest(bgCtx, bifrostTextReq) @@ -156,11 +155,10 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.ChatCompletionRequest(bgCtx, bifrostChatReq) @@ -194,11 +192,10 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.ResponsesRequest(bgCtx, bifrostResponsesReq) @@ -228,11 +225,10 @@ func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.EmbeddingRequest(bgCtx, bifrostEmbeddingReq) @@ -266,11 +262,10 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.SpeechRequest(bgCtx, bifrostSpeechReq) @@ -304,11 +299,10 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.TranscriptionRequest(bgCtx, bifrostTranscriptionReq) @@ -342,11 +336,10 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.ImageGenerationRequest(bgCtx, bifrostReq) @@ -380,11 +373,10 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.ImageEditRequest(bgCtx, bifrostReq) @@ -413,11 +405,10 @@ func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.ImageVariationRequest(bgCtx, bifrostReq) @@ -446,11 +437,10 @@ func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.RerankRequest(bgCtx, bifrostReq) @@ -479,11 +469,10 @@ func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) { } defer cancel() - virtualKeyValue := getVirtualKeyFromContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL) job, err := h.executor.SubmitJob( - virtualKeyValue, + bifrostCtx, resultTTL, func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { return h.client.OCRRequest(bgCtx, bifrostReq) diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index cee9721eb9..d1aa276a8d 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -1474,7 +1474,6 @@ func (g *GenericRouter) handleAsyncCreate( } operationType := config.GetHTTPRequestType(ctx) - vkValue := getVirtualKeyFromBifrostContext(bifrostCtx) resultTTL := getResultTTLFromHeaderWithDefault(ctx, g.handlerStore.GetAsyncJobResultTTL()) // The operation closure runs the Bifrost client call in the background. @@ -1491,7 +1490,7 @@ func (g *GenericRouter) handleAsyncCreate( } } - job, err := executor.SubmitJob(vkValue, resultTTL, operation, operationType) + job, err := executor.SubmitJob(bifrostCtx, resultTTL, operation, operationType) if err != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to create async job"))