Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion framework/streaming/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, resu

isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest
isChatStreaming := requestType == schemas.ChatCompletionStreamRequest || requestType == schemas.TextCompletionStreamRequest
isResponsesStreaming := requestType == schemas.ResponsesStreamRequest
isResponsesStreaming := requestType == schemas.ResponsesStreamRequest || requestType == schemas.WebSocketResponsesRequest
// Edit images/ Image variation requests will be added here
isImageStreaming := requestType == schemas.ImageGenerationStreamRequest || requestType == schemas.ImageEditStreamRequest

Expand Down
170 changes: 170 additions & 0 deletions framework/streaming/accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,173 @@ func TestGetLastAudioAndTranscriptionChunksSafe(t *testing.T) {
t.Errorf("Expected transcription chunk index 3, got %d", lastTranscription.ChunkIndex)
}
}

// TestWebSocketResponsesRequestAccumulatedAsResponsesStreaming verifies that
// ProcessStreamingResponse treats WebSocketResponsesRequest the same as
// ResponsesStreamRequest so that WS chunks are accumulated and token usage
// is captured. This is a regression test for the bug described in issue #3001.
func TestWebSocketResponsesRequestAccumulatedAsResponsesStreaming(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)

requestID := "test-ws-responses-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)

const inputTokens = 120
const outputTokens = 60
const totalTokens = 180

// Add intermediate chunks before the final one.
for i := 0; i < 3; i++ {
chunk := &ResponsesStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
StreamResponse: &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeOutputTextDelta,
},
}
err := accumulator.addResponsesStreamChunk(requestID, chunk, false)
if err != nil {
t.Fatalf("Failed to add chunk %d: %v", i, err)
}
}

// Final chunk carries token usage.
finalChunk := &ResponsesStreamChunk{
ChunkIndex: 3,
Timestamp: time.Now(),
TokenUsage: &schemas.BifrostLLMUsage{
PromptTokens: inputTokens,
CompletionTokens: outputTokens,
TotalTokens: totalTokens,
},
StreamResponse: &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeCompleted,
Response: &schemas.BifrostResponsesResponse{
Usage: &schemas.ResponsesResponseUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
},
},
},
}
err := accumulator.addResponsesStreamChunk(requestID, finalChunk, true)
if err != nil {
t.Fatalf("Failed to add final chunk: %v", err)
}

// Build a response carrying WebSocketResponsesRequest as the request type.
response := &schemas.BifrostResponse{
ResponsesResponse: &schemas.BifrostResponsesResponse{
ID: bifrost.Ptr("ws-msg-001"),
Usage: &schemas.ResponsesResponseUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.WebSocketResponsesRequest,
Provider: schemas.OpenAI,
OriginalModelRequested: "gpt-4o",
ChunkIndex: 3,
},
},
}

ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)

processed, processErr := accumulator.ProcessStreamingResponse(ctx, response, nil)

if processErr != nil {
t.Fatalf("ProcessStreamingResponse returned unexpected error for WebSocketResponsesRequest: %v", processErr)
}
if processed == nil {
t.Fatal("ProcessStreamingResponse returned nil result for WebSocketResponsesRequest")
}
if processed.Data == nil {
t.Fatal("Processed response Data is nil; accumulator did not accumulate for WebSocketResponsesRequest")
}
if processed.Data.TokenUsage == nil {
t.Fatal("Token usage is nil; accumulator did not capture usage for WebSocketResponsesRequest")
}
if processed.Data.TokenUsage.PromptTokens != inputTokens {
t.Errorf("Expected PromptTokens=%d, got %d", inputTokens, processed.Data.TokenUsage.PromptTokens)
}
if processed.Data.TokenUsage.CompletionTokens != outputTokens {
t.Errorf("Expected CompletionTokens=%d, got %d", outputTokens, processed.Data.TokenUsage.CompletionTokens)
}
if processed.Data.TokenUsage.TotalTokens != totalTokens {
t.Errorf("Expected TotalTokens=%d, got %d", totalTokens, processed.Data.TokenUsage.TotalTokens)
}
}

// TestWebSocketResponsesRequestVsResponsesStreamRequestParity verifies that
// ProcessStreamingResponse treats WebSocketResponsesRequest identically to
// ResponsesStreamRequest and does NOT return the "request type missing/invalid"
// error that was produced before the fix.
func TestWebSocketResponsesRequestVsResponsesStreamRequestParity(t *testing.T) {
	logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)

	// setup builds an accumulator pre-loaded with one completed chunk plus the
	// context and response needed to drive ProcessStreamingResponse for the
	// given request type.
	setup := func(requestType schemas.RequestType, requestID string) (*Accumulator, *schemas.BifrostContext, *schemas.BifrostResponse) {
		acc := NewAccumulator(nil, logger)

		ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
		ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)

		// Two distinct usage values: one embedded in the stream chunk, one in
		// the response itself (kept as separate allocations, as before).
		chunkUsage := &schemas.ResponsesResponseUsage{
			InputTokens:  10,
			OutputTokens: 5,
		}
		respUsage := &schemas.ResponsesResponseUsage{
			InputTokens:  10,
			OutputTokens: 5,
		}

		chunk := &ResponsesStreamChunk{
			ChunkIndex: 0,
			Timestamp:  time.Now(),
			TokenUsage: &schemas.BifrostLLMUsage{
				PromptTokens:     10,
				CompletionTokens: 5,
				TotalTokens:      15,
			},
			StreamResponse: &schemas.BifrostResponsesStreamResponse{
				Type:     schemas.ResponsesStreamResponseTypeCompleted,
				Response: &schemas.BifrostResponsesResponse{Usage: chunkUsage},
			},
		}
		if err := acc.addResponsesStreamChunk(requestID, chunk, true); err != nil {
			t.Fatalf("Failed to add chunk for %s: %v", requestType, err)
		}

		ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)

		resp := &schemas.BifrostResponse{
			ResponsesResponse: &schemas.BifrostResponsesResponse{
				ID:    bifrost.Ptr("parity-test"),
				Usage: respUsage,
				ExtraFields: schemas.BifrostResponseExtraFields{
					RequestType:            requestType,
					Provider:               schemas.OpenAI,
					OriginalModelRequested: "gpt-4o",
					ChunkIndex:             0,
				},
			},
		}
		return acc, ctx, resp
	}

	requestTypes := []schemas.RequestType{schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest}
	for _, requestType := range requestTypes {
		requestType := requestType // per-iteration copy for the subtest closure (pre-1.22 safety)
		t.Run(string(requestType), func(t *testing.T) {
			acc, ctx, resp := setup(requestType, "parity-"+string(requestType))
			processed, err := acc.ProcessStreamingResponse(ctx, resp, nil)
			if err != nil {
				t.Errorf("RequestType %s: unexpected error: %v", requestType, err)
			}
			if processed == nil {
				t.Errorf("RequestType %s: expected non-nil processed response", requestType)
			}
		})
	}
}
2 changes: 1 addition & 1 deletion plugins/logging/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r
streamType = streaming.StreamTypeText
case schemas.ChatCompletionStreamRequest:
streamType = streaming.StreamTypeChat
case schemas.ResponsesStreamRequest:
case schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest:
streamType = streaming.StreamTypeResponses
case schemas.SpeechStreamRequest:
streamType = streaming.StreamTypeAudio
Expand Down