diff --git a/framework/streaming/accumulator.go b/framework/streaming/accumulator.go index 66fe43339c..24cec17dad 100644 --- a/framework/streaming/accumulator.go +++ b/framework/streaming/accumulator.go @@ -397,7 +397,7 @@ func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, resu isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest isChatStreaming := requestType == schemas.ChatCompletionStreamRequest || requestType == schemas.TextCompletionStreamRequest - isResponsesStreaming := requestType == schemas.ResponsesStreamRequest + isResponsesStreaming := requestType == schemas.ResponsesStreamRequest || requestType == schemas.WebSocketResponsesRequest // Edit images/ Image variation requests will be added here isImageStreaming := requestType == schemas.ImageGenerationStreamRequest || requestType == schemas.ImageEditStreamRequest diff --git a/framework/streaming/accumulator_test.go b/framework/streaming/accumulator_test.go index 18eb43f71b..854d5ebfbf 100644 --- a/framework/streaming/accumulator_test.go +++ b/framework/streaming/accumulator_test.go @@ -659,3 +659,173 @@ func TestGetLastAudioAndTranscriptionChunksSafe(t *testing.T) { t.Errorf("Expected transcription chunk index 3, got %d", lastTranscription.ChunkIndex) } } + +// TestWebSocketResponsesRequestAccumulatedAsResponsesStreaming verifies that +// ProcessStreamingResponse treats WebSocketResponsesRequest the same as +// ResponsesStreamRequest so that WS chunks are accumulated and token usage +// is captured. This is a regression test for the bug described in issue #3001. 
+func TestWebSocketResponsesRequestAccumulatedAsResponsesStreaming(t *testing.T) { + logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) + accumulator := NewAccumulator(nil, logger) + + requestID := "test-ws-responses-request" + ctx := schemas.NewBifrostContext(context.Background(), time.Time{}) + ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID) + + const inputTokens = 120 + const outputTokens = 60 + const totalTokens = 180 + + // Add intermediate chunks before the final one. + for i := 0; i < 3; i++ { + chunk := &ResponsesStreamChunk{ + ChunkIndex: i, + Timestamp: time.Now(), + StreamResponse: &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + }, + } + err := accumulator.addResponsesStreamChunk(requestID, chunk, false) + if err != nil { + t.Fatalf("Failed to add chunk %d: %v", i, err) + } + } + + // Final chunk carries token usage. + finalChunk := &ResponsesStreamChunk{ + ChunkIndex: 3, + Timestamp: time.Now(), + TokenUsage: &schemas.BifrostLLMUsage{ + PromptTokens: inputTokens, + CompletionTokens: outputTokens, + TotalTokens: totalTokens, + }, + StreamResponse: &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + Response: &schemas.BifrostResponsesResponse{ + Usage: &schemas.ResponsesResponseUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + }, + }, + }, + } + err := accumulator.addResponsesStreamChunk(requestID, finalChunk, true) + if err != nil { + t.Fatalf("Failed to add final chunk: %v", err) + } + + // Build a response carrying WebSocketResponsesRequest as the request type. 
+ response := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + ID: bifrost.Ptr("ws-msg-001"), + Usage: &schemas.ResponsesResponseUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.WebSocketResponsesRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", + ChunkIndex: 3, + }, + }, + } + + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + + processed, processErr := accumulator.ProcessStreamingResponse(ctx, response, nil) + + if processErr != nil { + t.Fatalf("ProcessStreamingResponse returned unexpected error for WebSocketResponsesRequest: %v", processErr) + } + if processed == nil { + t.Fatal("ProcessStreamingResponse returned nil result for WebSocketResponsesRequest") + } + if processed.Data == nil { + t.Fatal("Processed response Data is nil; accumulator did not accumulate for WebSocketResponsesRequest") + } + if processed.Data.TokenUsage == nil { + t.Fatal("Token usage is nil; accumulator did not capture usage for WebSocketResponsesRequest") + } + if processed.Data.TokenUsage.PromptTokens != inputTokens { + t.Errorf("Expected PromptTokens=%d, got %d", inputTokens, processed.Data.TokenUsage.PromptTokens) + } + if processed.Data.TokenUsage.CompletionTokens != outputTokens { + t.Errorf("Expected CompletionTokens=%d, got %d", outputTokens, processed.Data.TokenUsage.CompletionTokens) + } + if processed.Data.TokenUsage.TotalTokens != totalTokens { + t.Errorf("Expected TotalTokens=%d, got %d", totalTokens, processed.Data.TokenUsage.TotalTokens) + } +} + +// TestWebSocketResponsesRequestVsResponsesStreamRequestParity verifies that +// ProcessStreamingResponse treats WebSocketResponsesRequest identically to +// ResponsesStreamRequest and does NOT return the "request type missing/invalid" +// error that was produced before the fix. 
+func TestWebSocketResponsesRequestVsResponsesStreamRequestParity(t *testing.T) { + logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) + + makeAccumulatorWithChunk := func(requestType schemas.RequestType, requestID string) (*Accumulator, *schemas.BifrostContext, *schemas.BifrostResponse) { + acc := NewAccumulator(nil, logger) + + ctx := schemas.NewBifrostContext(context.Background(), time.Time{}) + ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID) + + chunk := &ResponsesStreamChunk{ + ChunkIndex: 0, + Timestamp: time.Now(), + TokenUsage: &schemas.BifrostLLMUsage{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + StreamResponse: &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + Response: &schemas.BifrostResponsesResponse{ + Usage: &schemas.ResponsesResponseUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + }, + }, + } + err := acc.addResponsesStreamChunk(requestID, chunk, true) + if err != nil { + t.Fatalf("Failed to add chunk for %s: %v", requestType, err) + } + + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + + resp := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + ID: bifrost.Ptr("parity-test"), + Usage: &schemas.ResponsesResponseUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: requestType, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", + ChunkIndex: 0, + }, + }, + } + return acc, ctx, resp + } + + for _, rt := range []schemas.RequestType{schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest} { + rt := rt + t.Run(string(rt), func(t *testing.T) { + acc, ctx, resp := makeAccumulatorWithChunk(rt, "parity-"+string(rt)) + processed, err := acc.ProcessStreamingResponse(ctx, resp, nil) + if err != nil { + t.Errorf("RequestType %s: unexpected error: %v", rt, err) + } + if processed == nil { + t.Errorf("RequestType %s: expected non-nil processed response", rt) + } + }) + } +} diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 4d1abbbde5..15f74ee830 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -655,7 +655,7 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r streamType = streaming.StreamTypeText case schemas.ChatCompletionStreamRequest: streamType = streaming.StreamTypeChat - case schemas.ResponsesStreamRequest: + case schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest: streamType = streaming.StreamTypeResponses case schemas.SpeechStreamRequest: streamType = streaming.StreamTypeAudio