Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions framework/changelog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- fix: preserve context values in async requests
28 changes: 25 additions & 3 deletions framework/logstore/asyncjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/bytedance/sonic"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/valyala/fasthttp"
Expand Down Expand Up @@ -80,11 +81,13 @@ func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValu
}

// SubmitJob creates a pending job, starts background execution, and returns the job record.
func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) {
func (e *AsyncJobExecutor) SubmitJob(bifrostCtx *schemas.BifrostContext, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) {
if resultTTL <= 0 {
resultTTL = DefaultAsyncJobResultTTL
}

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)

var virtualKeyID *string
if virtualKeyValue != nil {
vk, ok := e.governanceStore.GetVirtualKey(*virtualKeyValue)
Expand All @@ -109,15 +112,24 @@ func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, ope
return nil, fmt.Errorf("failed to create async job: %w", err)
}

go e.executeJob(job.ID, job.ResultTTL, operation)
go e.executeJob(job.ID, job.ResultTTL, operation, bifrostCtx.GetUserValues())
Comment thread
TejasGhatte marked this conversation as resolved.

return job, nil
}

// executeJob runs the operation in the background and updates the job record.
func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation) {
func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation, contextValues map[any]any) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

// Restore original request context values (virtual key, tracing headers, etc.)
for k, v := range contextValues {
ctx.SetValue(k, v)
}

ctx.ClearValue(schemas.BifrostContextKeyTraceID)
ctx.ClearValue(schemas.BifrostContextKeyParentSpanID)
ctx.ClearValue(schemas.BifrostContextKeySpanID)

markFailed := func(msg string) {
now := time.Now().UTC()
expiresAt := now.Add(time.Duration(resultTTL) * time.Second)
Expand Down Expand Up @@ -284,3 +296,13 @@ func (c *AsyncJobCleaner) cleanupExpiredJobs(ctx context.Context) {
c.logger.Warn("async job cleanup: deleted %d stale processing jobs (stuck > %dh)", staleDeleted, asyncJobStaleProcessingHours)
}
}

// getVirtualKeyFromContext looks up the virtual key string stored on the
// Bifrost context under schemas.BifrostContextKeyVirtualKey.
//
// It returns a pointer to that string, or nil when the lookup yields an empty
// string (no virtual key present, e.g., direct key mode or no governance).
func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
	if key := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey); key != "" {
		return &key
	}
	return nil
}
33 changes: 11 additions & 22 deletions transports/bifrost-http/handlers/asyncinference.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,10 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.TextCompletionRequest(bgCtx, bifrostTextReq)
Expand Down Expand Up @@ -156,11 +155,10 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.ChatCompletionRequest(bgCtx, bifrostChatReq)
Expand Down Expand Up @@ -194,11 +192,10 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.ResponsesRequest(bgCtx, bifrostResponsesReq)
Expand Down Expand Up @@ -228,11 +225,10 @@ func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.EmbeddingRequest(bgCtx, bifrostEmbeddingReq)
Expand Down Expand Up @@ -266,11 +262,10 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.SpeechRequest(bgCtx, bifrostSpeechReq)
Expand Down Expand Up @@ -304,11 +299,10 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.TranscriptionRequest(bgCtx, bifrostTranscriptionReq)
Expand Down Expand Up @@ -342,11 +336,10 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.ImageGenerationRequest(bgCtx, bifrostReq)
Expand Down Expand Up @@ -380,11 +373,10 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.ImageEditRequest(bgCtx, bifrostReq)
Expand Down Expand Up @@ -413,11 +405,10 @@ func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.ImageVariationRequest(bgCtx, bifrostReq)
Expand Down Expand Up @@ -446,11 +437,10 @@ func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.RerankRequest(bgCtx, bifrostReq)
Expand Down Expand Up @@ -479,11 +469,10 @@ func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) {
}
defer cancel()

virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)

job, err := h.executor.SubmitJob(
virtualKeyValue,
bifrostCtx,
resultTTL,
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
return h.client.OCRRequest(bgCtx, bifrostReq)
Expand Down
3 changes: 1 addition & 2 deletions transports/bifrost-http/integrations/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,6 @@ func (g *GenericRouter) handleAsyncCreate(
}

operationType := config.GetHTTPRequestType(ctx)
vkValue := getVirtualKeyFromBifrostContext(bifrostCtx)
resultTTL := getResultTTLFromHeaderWithDefault(ctx, g.handlerStore.GetAsyncJobResultTTL())

// The operation closure runs the Bifrost client call in the background.
Expand All @@ -1491,7 +1490,7 @@ func (g *GenericRouter) handleAsyncCreate(
}
}

job, err := executor.SubmitJob(vkValue, resultTTL, operation, operationType)
job, err := executor.SubmitJob(bifrostCtx, resultTTL, operation, operationType)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if err != nil {
g.sendError(ctx, bifrostCtx, config.ErrorConverter,
newBifrostError(err, "failed to create async job"))
Expand Down
Loading