diff --git a/core/bifrost.go b/core/bifrost.go index f136b4d89d..88e2a88673 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5005,6 +5005,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch if bifrostError != nil { return nil, bifrostError } + videoGenerationResponse.BackfillParams(&req.BifrostRequest) response.VideoGenerationResponse = videoGenerationResponse case schemas.VideoRetrieveRequest: videoRetrieveResponse, bifrostError := provider.VideoRetrieve(req.Context, key, req.BifrostRequest.VideoRetrieveRequest) diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 28fac109f2..fc758a16ee 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -2474,7 +2474,6 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s return nil, bifrostErr } - // Convert to Bifrost response bifrostResp, bifrostErr := ToBifrostVideoGenerationResponse(&operation, "") if bifrostErr != nil { return nil, bifrostErr diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 9815ee1676..f88594c29d 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -1454,169 +1454,54 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } if currentEvent.Event != "" { - // Process the event - switch currentEvent.Event { - case "output": - // Text chunk received - if currentEvent.Data != "" { - // Accumulate raw response if enabled - if sendBackRawResponse { - rawResponseChunks = append(rawResponseChunks, currentEvent) - } - - // Emit lifecycle events on first content - if !hasEmittedCreated { - // response.created - createdResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeCreated, - SequenceNumber: sequenceNumber, - Response: &schemas.BifrostResponsesResponse{ - ID: schemas.Ptr(messageID), - Model: request.Model, - CreatedAt: int(startTime.Unix()), - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - ChunkIndex: sequenceNumber, - }, - } - if sendBackRawRequest { - providerUtils.ParseAndSetRawRequest(&createdResp.ExtraFields, jsonData) - } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, createdResp, nil, nil, nil), - responseChan) - sequenceNumber++ - hasEmittedCreated = true - } - - if !hasEmittedInProgress { - // response.in_progress - inProgressResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeInProgress, - SequenceNumber: sequenceNumber, - Response: &schemas.BifrostResponsesResponse{ - ID: schemas.Ptr(messageID), - CreatedAt: int(startTime.Unix()), - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, - }, - } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, inProgressResp, nil, nil, nil), - responseChan) - sequenceNumber++ - hasEmittedInProgress = true - } - - if !hasEmittedOutputItemAdded { - // response.output_item.added - messageType := schemas.ResponsesMessageTypeMessage - role := schemas.ResponsesInputMessageRoleAssistant - status := "in_progress" - itemAddedResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - Item: &schemas.ResponsesMessage{ - ID: schemas.Ptr(itemID), - Type: &messageType, - Role: &role, - Status: &status, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{}, - }, - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, - }, - } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemAddedResp, nil, nil, nil), - responseChan) - sequenceNumber++ - hasEmittedOutputItemAdded = true - } - - if !hasEmittedContentPartAdded { - // response.content_part.added - emptyText := "" - partAddedResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(contentIndex), - ItemID: schemas.Ptr(itemID), - Part: &schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &emptyText, - ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{ - Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{}, - LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, - }, - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, - }, - } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partAddedResp, nil, nil, nil), - responseChan) - sequenceNumber++ - hasEmittedContentPartAdded = true - } + // Process the event + switch currentEvent.Event { + case "output": + // Text chunk received + if currentEvent.Data != "" { + // Accumulate raw response if enabled + if sendBackRawResponse { + rawResponseChunks = append(rawResponseChunks, currentEvent) + } - // response.output_text.delta - deltaResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + // Emit lifecycle events on first content + if !hasEmittedCreated { + // response.created + createdResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCreated, SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(contentIndex), - ItemID: schemas.Ptr(itemID), - Delta: schemas.Ptr(currentEvent.Data), - LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, + Response: &schemas.BifrostResponsesResponse{ + ID: schemas.Ptr(messageID), + Model: request.Model, + CreatedAt: int(startTime.Unix()), + }, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesStreamRequest, Provider: provider.GetProviderKey(), ModelRequested: request.Model, + Latency: time.Since(startTime).Milliseconds(), ChunkIndex: sequenceNumber, }, } + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&createdResp.ExtraFields, jsonData) + } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, deltaResp, nil, nil, nil), + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, createdResp, nil, nil, nil), responseChan) sequenceNumber++ - hasReceivedContent = true - } - case "done": - // Accumulate done event in raw responses if enabled - if sendBackRawResponse { - rawResponseChunks = append(rawResponseChunks, currentEvent) + hasEmittedCreated = true } - // Stream completed - if hasReceivedContent { - // response.output_text.done - textDoneResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + if !hasEmittedInProgress { + // response.in_progress + inProgressResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeInProgress, SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(contentIndex), - ItemID: schemas.Ptr(itemID), - LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, + Response: &schemas.BifrostResponsesResponse{ + ID: schemas.Ptr(messageID), + CreatedAt: int(startTime.Unix()), + }, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesStreamRequest, Provider: provider.GetProviderKey(), @@ -1625,22 +1510,28 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, textDoneResp, nil, nil, nil), + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, inProgressResp, nil, nil, nil), responseChan) sequenceNumber++ + hasEmittedInProgress = true + } - // response.content_part.done - partDoneResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartDone, + if !hasEmittedOutputItemAdded { + // response.output_item.added + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + status := "in_progress" + itemAddedResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, SequenceNumber: sequenceNumber, OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(contentIndex), - ItemID: schemas.Ptr(itemID), - Part: &schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{ - Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{}, - LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, + Item: &schemas.ResponsesMessage{ + ID: schemas.Ptr(itemID), + Type: &messageType, + Role: &role, + Status: &status, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ @@ -1651,33 +1542,27 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partDoneResp, nil, nil, nil), + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemAddedResp, nil, nil, nil), responseChan) sequenceNumber++ + hasEmittedOutputItemAdded = true + } - // response.output_item.done - messageType := schemas.ResponsesMessageTypeMessage - role := schemas.ResponsesInputMessageRoleAssistant - status := "completed" - itemDoneResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + if !hasEmittedContentPartAdded { + // response.content_part.added + emptyText := "" + partAddedResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, SequenceNumber: sequenceNumber, OutputIndex: schemas.Ptr(outputIndex), - Item: &schemas.ResponsesMessage{ - ID: schemas.Ptr(itemID), - Type: &messageType, - Role: &role, - Status: &status, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{ - Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{}, - LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, - }, - }, - }, + ContentIndex: schemas.Ptr(contentIndex), + ItemID: schemas.Ptr(itemID), + Part: &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &emptyText, + ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{ + Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{}, + LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ @@ -1688,80 +1573,195 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemDoneResp, nil, nil, nil), + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partAddedResp, nil, nil, nil), responseChan) sequenceNumber++ + hasEmittedContentPartAdded = true } - // response.completed - completedResp := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeCompleted, + // response.output_text.delta + deltaResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, SequenceNumber: sequenceNumber, - Response: &schemas.BifrostResponsesResponse{ - ID: schemas.Ptr(messageID), - Model: request.Model, - CreatedAt: int(startTime.Unix()), - CompletedAt: schemas.Ptr(int(time.Now().Unix())), - }, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(contentIndex), + ItemID: schemas.Ptr(itemID), + Delta: schemas.Ptr(currentEvent.Data), + LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesStreamRequest, Provider: provider.GetProviderKey(), ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), ChunkIndex: sequenceNumber, }, } + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, deltaResp, nil, nil, nil), + responseChan) + sequenceNumber++ + hasReceivedContent = true + } + case "done": + // Accumulate done event in raw responses if enabled + if sendBackRawResponse { + rawResponseChunks = append(rawResponseChunks, currentEvent) + } - // Set raw request if enabled (on final chunk only) - if sendBackRawRequest { - providerUtils.ParseAndSetRawRequest(&completedResp.ExtraFields, jsonData) + // Stream completed + if hasReceivedContent { + // response.output_text.done + textDoneResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(contentIndex), + ItemID: schemas.Ptr(itemID), + LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: provider.GetProviderKey(), + ModelRequested: request.Model, + ChunkIndex: sequenceNumber, + }, } + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, textDoneResp, nil, nil, nil), + responseChan) + sequenceNumber++ - // Set raw response if enabled - if sendBackRawResponse && len(rawResponseChunks) > 0 { - completedResp.ExtraFields.RawResponse = rawResponseChunks + // response.content_part.done + partDoneResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(contentIndex), + ItemID: schemas.Ptr(itemID), + Part: &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{ + Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{}, + LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: provider.GetProviderKey(), + ModelRequested: request.Model, + ChunkIndex: sequenceNumber, + }, } - - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, - providerUtils.GetBifrostResponseForStreamResponse(nil, nil, completedResp, nil, nil, nil), + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, partDoneResp, nil, nil, nil), responseChan) - resp.CloseBodyStream() - return - case "error": - // Accumulate error event in raw responses if enabled - if sendBackRawResponse { - rawResponseChunks = append(rawResponseChunks, currentEvent) + sequenceNumber++ + + // response.output_item.done + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + status := "completed" + itemDoneResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + Item: &schemas.ResponsesMessage{ + ID: schemas.Ptr(itemID), + Type: &messageType, + Role: &role, + Status: &status, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + ResponsesOutputMessageContentText: &schemas.ResponsesOutputMessageContentText{ + Annotations: []schemas.ResponsesOutputMessageContentTextAnnotation{}, + LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, + }, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: provider.GetProviderKey(), + ModelRequested: request.Model, + ChunkIndex: sequenceNumber, + }, } + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, itemDoneResp, nil, nil, nil), + responseChan) + sequenceNumber++ + } - // Handle error - errorMsg := "stream error" - if currentEvent.Data != "" { - errorMsg = currentEvent.Data - } - bifrostErr := providerUtils.NewBifrostOperationError( - errorMsg, - fmt.Errorf("stream error: %s", errorMsg), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + // response.completed + completedResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + SequenceNumber: sequenceNumber, + Response: &schemas.BifrostResponsesResponse{ + ID: schemas.Ptr(messageID), + Model: request.Model, + CreatedAt: int(startTime.Unix()), + CompletedAt: schemas.Ptr(int(time.Now().Unix())), + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, Provider: provider.GetProviderKey(), ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, + }, + } - // Include accumulated raw responses in error - if sendBackRawResponse && len(rawResponseChunks) > 0 { - bifrostErr.ExtraFields.RawResponse = rawResponseChunks - } + // Set raw request if enabled (on final chunk only) + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&completedResp.ExtraFields, jsonData) + } - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) - resp.CloseBodyStream() - return + // Set raw response if enabled + if sendBackRawResponse && len(rawResponseChunks) > 0 { + completedResp.ExtraFields.RawResponse = rawResponseChunks + } + + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, completedResp, nil, nil, nil), + responseChan) + resp.CloseBodyStream() + return + case "error": + // Accumulate error event in raw responses if enabled + if sendBackRawResponse { + rawResponseChunks = append(rawResponseChunks, currentEvent) + } + + // Handle error + errorMsg := "stream error" + if currentEvent.Data != "" { + errorMsg = currentEvent.Data + } + bifrostErr := providerUtils.NewBifrostOperationError( + errorMsg, + fmt.Errorf("stream error: %s", errorMsg), + provider.GetProviderKey(), + ) + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + Provider: provider.GetProviderKey(), + ModelRequested: request.Model, + RequestType: schemas.ResponsesStreamRequest, } + + // Include accumulated raw responses in error + if sendBackRawResponse && len(rawResponseChunks) > 0 { + bifrostErr.ExtraFields.RawResponse = rawResponseChunks + } + + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) + resp.CloseBodyStream() + return } + } } }() diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 73b663c276..4170dc599e 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -2423,7 +2423,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key if bifrostErr != nil { return nil, bifrostErr } - // Convert to Bifrost response using Gemini converter bifrostResp, bifrostErr := gemini.ToBifrostVideoGenerationResponse(&operation, bifrostReq.Model) if bifrostErr != nil { @@ -2559,7 +2558,6 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s return nil, bifrostErr } - // Convert to Bifrost response using Gemini converter bifrostResp, bifrostErr := gemini.ToBifrostVideoGenerationResponse(&operation, "") if bifrostErr != nil { return nil, bifrostErr diff --git a/core/schemas/videos.go b/core/schemas/videos.go index 0df26b9def..9e133c7d52 100644 --- a/core/schemas/videos.go +++ b/core/schemas/videos.go @@ -97,6 +97,9 @@ type VideoGenerationParameters struct { ExtraParams map[string]any `json:"-"` } +// DefaultVideoDuration is the default video duration in seconds for Gemini/Vertex when not specified. +const DefaultVideoDuration = "8" + // BifrostVideoGenerationResponse represents the video generation job response in bifrost format. type BifrostVideoGenerationResponse struct { ID string `json:"id,omitempty"` @@ -118,6 +121,43 @@ type BifrostVideoGenerationResponse struct { ExtraFields BifrostResponseExtraFields `json:"extra_fields,omitempty"` } +// getSecondsFromVideoRequest extracts Seconds from video-related requests. +func getSecondsFromVideoRequest(req *BifrostRequest) *string { + if req == nil { + return nil + } + useDefaultForSeconds := func(p ModelProvider) bool { + return p == Gemini || p == Vertex + } + if req.VideoGenerationRequest != nil { + var seconds *string + if req.VideoGenerationRequest.Params != nil { + seconds = req.VideoGenerationRequest.Params.Seconds + } + if seconds == nil && useDefaultForSeconds(req.VideoGenerationRequest.Provider) { + seconds = Ptr(DefaultVideoDuration) + } + return seconds + } + if req.VideoRemixRequest != nil && useDefaultForSeconds(req.VideoRemixRequest.Provider) { + return Ptr(DefaultVideoDuration) + } + return nil +} + +// BackfillParams populates response fields from the original request that are needed +// for cost calculation but may not be returned by the provider. +// - Seconds (duration from request params or default) +func (r *BifrostVideoGenerationResponse) BackfillParams(req *BifrostRequest) { + if r == nil || req == nil { + return + } + seconds := getSecondsFromVideoRequest(req) + if seconds != nil { + r.Seconds = seconds + } +} + // --- Video Remix --- type BifrostVideoRemixRequest struct { diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index 3b91f429d3..6ebd9c5fea 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -18,6 +18,8 @@ func normalizeProvider(p string) string { return string(schemas.Bedrock) } else if strings.Contains(p, "cohere") { return string(schemas.Cohere) + } else if strings.Contains(p, "runwayml") { + return string(schemas.Runway) } else { return p }