Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions transports/bifrost-http/handlers/wsresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,20 +304,63 @@ func (h *WSResponsesHandler) tryNativeWSUpstream(

// Read response events from upstream and relay to client, running post-hooks per chunk
forwardedAny := false
terminalPostHookFired := false
for {
msgType, data, readErr := upstream.ReadMessage()
if readErr != nil {
logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
h.pool.Discard(upstream)
session.SetUpstream(nil)

if !forwardedAny {
// Nothing was forwarded yet: fall through to HTTP bridge.
logger.Warn("upstream WS read failed for %s (no data forwarded): %v, falling back to HTTP bridge", req.Provider, readErr)
return false
}
writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")

// We already forwarded content. If the terminal post-hook has not
// fired yet (full struct parse failed or the upstream closed without
// sending a recognized terminal event), synthesize a completed event
// so the logging plugin writes the DB row and the UI transitions out
// of "processing".
if !terminalPostHookFired {
isClean := ws.IsCloseError(readErr, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseNoStatusReceived)
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
synth := synthesizeTerminalStreamResponse(req.Provider, req.Model, schemas.ResponsesStreamResponseTypeCompleted)
synthResp := &schemas.BifrostResponse{ResponsesStreamResponse: synth}
if tracer != nil && traceID != "" {
tracer.AddStreamingChunk(traceID, synthResp)
}
hooks.PostHookRunner(ctx, synthResp, nil) //nolint:errcheck
if isClean {
logger.Debug("upstream WS closed cleanly for %s without terminal event; synthesized response.completed for logging", req.Provider)
} else {
logger.Warn("upstream WS read failed for %s after forwarding data: %v; finalizing log as error", req.Provider, readErr)
writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
}
}
return true
}

streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)

// When the full parse failed, attempt a lightweight type extraction so
// that terminal-event detection works even for event shapes that don't
// fully unmarshal (e.g. large response.completed frames or provider-
// specific events with unknown nested fields).
if streamResp == nil {
if rawType := extractStreamEventType(data); rawType != "" {
if isTerminalStreamType(rawType) {
// Full parse failed but we can still detect the terminal event.
// Synthesize a minimal struct so the post-hook and log finalize.
streamResp = synthesizeTerminalStreamResponse(req.Provider, req.Model, rawType)
} else {
logger.Debug("upstream WS event parse failed for type=%q; relaying raw bytes without post-hook", rawType)
}
} else {
logger.Debug("upstream WS event parse failed and type extraction failed; relaying raw bytes without post-hook")
}
}

isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)

if isTerminal {
Expand All @@ -338,6 +381,9 @@ func (h *WSResponsesHandler) tryNativeWSUpstream(
writeWSBifrostError(session, postErr)
return true
}
if isTerminal {
terminalPostHookFired = true
}
}

if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
Expand Down Expand Up @@ -393,6 +439,36 @@ func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model str
return &streamResp
}

// extractStreamEventType pulls only the top-level "type" field out of a raw
// upstream WS event by unmarshalling into a single-field envelope, so the rest
// of the payload does not need to match BifrostResponsesStreamResponse. It
// serves as the fallback when parseUpstreamWSEvent returns nil.
//
// The result is the empty string when unmarshalling fails or when the payload
// carries no "type" value.
func extractStreamEventType(data []byte) schemas.ResponsesStreamResponseType {
	type typeOnly struct {
		Type string `json:"type"`
	}

	var probe typeOnly
	if sonic.Unmarshal(data, &probe) != nil {
		// Malformed payload: report "no type" rather than a partial result.
		return ""
	}
	return schemas.ResponsesStreamResponseType(probe.Type)
}

// synthesizeTerminalStreamResponse returns a freshly allocated
// BifrostResponsesStreamResponse whose Type is eventType and whose ExtraFields
// carry the responses-stream request type, the provider, and the originally
// requested model. No other fields are set.
//
// Callers use it as a stand-in terminal event when the real upstream event
// could not be parsed or never arrived, so post-hooks still see a terminal
// response tagged with provider and model.
func synthesizeTerminalStreamResponse(provider schemas.ModelProvider, model string, eventType schemas.ResponsesStreamResponseType) *schemas.BifrostResponsesStreamResponse {
	resp := new(schemas.BifrostResponsesStreamResponse)
	resp.Type = eventType
	resp.ExtraFields.OriginalModelRequested = model
	resp.ExtraFields.Provider = provider
	resp.ExtraFields.RequestType = schemas.ResponsesStreamRequest
	return resp
}

// isTerminalStreamType returns true if the event type signals the end of a response stream.
func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool {
switch t {
Expand Down
97 changes: 97 additions & 0 deletions transports/bifrost-http/handlers/wsresponses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,100 @@ func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T)
t.Fatalf("parent request id should be unset, got %#v", got)
}
}

// ---------------------------------------------------------------------------
// extractStreamEventType: lightweight type field extraction
// ---------------------------------------------------------------------------

func TestExtractStreamEventType_ValidTerminal(t *testing.T) {
got := extractStreamEventType([]byte(`{"type":"response.completed","sequence_number":5}`))
if got != schemas.ResponsesStreamResponseTypeCompleted {
t.Errorf("got %q, want %q", got, schemas.ResponsesStreamResponseTypeCompleted)
}
}

func TestExtractStreamEventType_ValidNonTerminal(t *testing.T) {
got := extractStreamEventType([]byte(`{"type":"response.output_text.delta","delta":"hello"}`))
if got != schemas.ResponsesStreamResponseTypeOutputTextDelta {
t.Errorf("got %q, want %q", got, schemas.ResponsesStreamResponseTypeOutputTextDelta)
}
}

// TestExtractStreamEventType_MalformedJSON checks that input which is not
// JSON produces the empty event type.
func TestExtractStreamEventType_MalformedJSON(t *testing.T) {
	input := []byte(`not json at all`)

	if eventType := extractStreamEventType(input); eventType != "" {
		t.Errorf("expected empty string for malformed JSON, got %q", eventType)
	}
}

// TestExtractStreamEventType_MissingTypeField checks that valid JSON without
// a "type" key produces the empty event type.
func TestExtractStreamEventType_MissingTypeField(t *testing.T) {
	input := []byte(`{"sequence_number":1,"delta":"hello"}`)

	if eventType := extractStreamEventType(input); eventType != "" {
		t.Errorf("expected empty string for missing type field, got %q", eventType)
	}
}

func TestExtractStreamEventType_UnknownExtraFields(t *testing.T) {
// Simulates a large or provider-specific event with unknown nested structure.
raw := []byte(`{"type":"response.completed","some_unknown_field":{"nested":{"deeply":"yes"}},"another_unknown":null}`)
got := extractStreamEventType(raw)
if got != schemas.ResponsesStreamResponseTypeCompleted {
t.Errorf("got %q, want %q", got, schemas.ResponsesStreamResponseTypeCompleted)
}
}

// ---------------------------------------------------------------------------
// synthesizeTerminalStreamResponse: minimal struct construction
// ---------------------------------------------------------------------------

func TestSynthesizeTerminalStreamResponse_FieldsPopulated(t *testing.T) {
resp := synthesizeTerminalStreamResponse(schemas.OpenAI, "gpt-4o", schemas.ResponsesStreamResponseTypeCompleted)
if resp == nil {
t.Fatal("got nil response")
}
if resp.Type != schemas.ResponsesStreamResponseTypeCompleted {
t.Errorf("Type = %q, want %q", resp.Type, schemas.ResponsesStreamResponseTypeCompleted)
}
if resp.ExtraFields.Provider != schemas.OpenAI {
t.Errorf("Provider = %q, want %q", resp.ExtraFields.Provider, schemas.OpenAI)
}
if resp.ExtraFields.OriginalModelRequested != "gpt-4o" {
t.Errorf("OriginalModelRequested = %q, want %q", resp.ExtraFields.OriginalModelRequested, "gpt-4o")
}
if resp.ExtraFields.RequestType != schemas.ResponsesStreamRequest {
t.Errorf("RequestType = %v, want %v", resp.ExtraFields.RequestType, schemas.ResponsesStreamRequest)
}
}

// ---------------------------------------------------------------------------
// isTerminalStreamType: terminal detection
// ---------------------------------------------------------------------------

// TestIsTerminalStreamType_TerminalTypes checks that the completed, failed,
// incomplete, and error event types are all reported as terminal.
func TestIsTerminalStreamType_TerminalTypes(t *testing.T) {
	for _, eventType := range []schemas.ResponsesStreamResponseType{
		schemas.ResponsesStreamResponseTypeCompleted,
		schemas.ResponsesStreamResponseTypeFailed,
		schemas.ResponsesStreamResponseTypeIncomplete,
		schemas.ResponsesStreamResponseTypeError,
	} {
		if isTerminalStreamType(eventType) {
			continue
		}
		t.Errorf("expected %q to be terminal", eventType)
	}
}

// TestIsTerminalStreamType_NonTerminalTypes checks that in-flight event types,
// an unrecognized provider-specific type, and the empty type are all reported
// as non-terminal.
func TestIsTerminalStreamType_NonTerminalTypes(t *testing.T) {
	cases := []schemas.ResponsesStreamResponseType{
		schemas.ResponsesStreamResponseTypeOutputTextDelta,
		schemas.ResponsesStreamResponseTypeCreated,
		schemas.ResponsesStreamResponseTypeInProgress,
		schemas.ResponsesStreamResponseType("codex.rate_limits"),
		schemas.ResponsesStreamResponseType(""),
	}
	for i := range cases {
		if !isTerminalStreamType(cases[i]) {
			continue
		}
		t.Errorf("expected %q to be non-terminal", cases[i])
	}
}