diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 4d1abbbde5..15f74ee830 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -655,7 +655,7 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r streamType = streaming.StreamTypeText case schemas.ChatCompletionStreamRequest: streamType = streaming.StreamTypeChat - case schemas.ResponsesStreamRequest: + case schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest: streamType = streaming.StreamTypeResponses case schemas.SpeechStreamRequest: streamType = streaming.StreamTypeAudio diff --git a/plugins/logging/utils_test.go b/plugins/logging/utils_test.go new file mode 100644 index 0000000000..a7a74f921a --- /dev/null +++ b/plugins/logging/utils_test.go @@ -0,0 +1,95 @@ +package logging + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/streaming" +) + +// TestConvertToProcessedStreamResponseWebSocketResponsesRequest verifies that +// WebSocketResponsesRequest routes to StreamTypeResponses, not StreamTypeChat. +func TestConvertToProcessedStreamResponseWebSocketResponsesRequest(t *testing.T) { + result := &schemas.StreamAccumulatorResult{ + RequestID: "req-ws-001", + RequestedModel: "gpt-4o", + Provider: schemas.OpenAI, + Status: "success", + } + + got := convertToProcessedStreamResponse(result, schemas.WebSocketResponsesRequest) + + if got == nil { + t.Fatal("expected non-nil ProcessedStreamResponse") + } + if got.StreamType != streaming.StreamTypeResponses { + t.Errorf("StreamType = %q, want %q", got.StreamType, streaming.StreamTypeResponses) + } +} + +// TestConvertToProcessedStreamResponseResponsesStreamRequest verifies the existing +// ResponsesStreamRequest case still routes to StreamTypeResponses (regression guard). +func TestConvertToProcessedStreamResponseResponsesStreamRequest(t *testing.T) { + result := &schemas.StreamAccumulatorResult{ + RequestID: "req-resp-001", + RequestedModel: "gpt-4o", + Provider: schemas.OpenAI, + Status: "success", + } + + got := convertToProcessedStreamResponse(result, schemas.ResponsesStreamRequest) + + if got == nil { + t.Fatal("expected non-nil ProcessedStreamResponse") + } + if got.StreamType != streaming.StreamTypeResponses { + t.Errorf("StreamType = %q, want %q", got.StreamType, streaming.StreamTypeResponses) + } +} + +// TestConvertToProcessedStreamResponseWebSocketMatchesResponsesStream verifies that +// WebSocketResponsesRequest and ResponsesStreamRequest produce equivalent StreamType values. +func TestConvertToProcessedStreamResponseWebSocketMatchesResponsesStream(t *testing.T) { + result := &schemas.StreamAccumulatorResult{ + RequestID: "req-compare-001", + RequestedModel: "gpt-4o", + Provider: schemas.OpenAI, + Status: "success", + } + + wsResp := convertToProcessedStreamResponse(result, schemas.WebSocketResponsesRequest) + httpResp := convertToProcessedStreamResponse(result, schemas.ResponsesStreamRequest) + + if wsResp == nil || httpResp == nil { + t.Fatal("expected non-nil responses for both request types") + } + if wsResp.StreamType != httpResp.StreamType { + t.Errorf("StreamType mismatch: WebSocketResponsesRequest=%q, ResponsesStreamRequest=%q", + wsResp.StreamType, httpResp.StreamType) + } +} + +// TestConvertToProcessedStreamResponseNilResult verifies nil input returns nil output. +func TestConvertToProcessedStreamResponseNilResult(t *testing.T) { + got := convertToProcessedStreamResponse(nil, schemas.WebSocketResponsesRequest) + if got != nil { + t.Errorf("expected nil for nil input, got %v", got) + } +} + +// TestConvertToProcessedStreamResponseDefaultFallback verifies unknown request types +// still route to StreamTypeChat (existing behaviour). +func TestConvertToProcessedStreamResponseDefaultFallback(t *testing.T) { + result := &schemas.StreamAccumulatorResult{ + RequestID: "req-unknown-001", + } + + got := convertToProcessedStreamResponse(result, schemas.RequestType("unknown_type")) + + if got == nil { + t.Fatal("expected non-nil ProcessedStreamResponse") + } + if got.StreamType != streaming.StreamTypeChat { + t.Errorf("StreamType = %q, want %q (default fallback)", got.StreamType, streaming.StreamTypeChat) + } +}