Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion transports/bifrost-http/handlers/wsresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package handlers

import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"

"github.com/bytedance/sonic"
"github.com/fasthttp/router"
Expand Down Expand Up @@ -302,14 +306,44 @@ func (h *WSResponsesHandler) tryNativeWSUpstream(
tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)

// Determine the per-read idle deadline from provider config (or fallback constant).
idleTimeout := h.upstreamWSIdleTimeout(req.Provider)

// Read response events from upstream and relay to client, running post-hooks per chunk
forwardedAny := false
for {
// Set a per-read deadline so that a silent upstream stall (e.g. rate-limit
// hold, service degradation) does not block this goroutine indefinitely.
// The deadline is cleared after each successful read so that a long-running
// stream that sends a frame every <idleTimeout is never interrupted.
if setErr := upstream.SetReadDeadline(time.Now().Add(idleTimeout)); setErr != nil {
logger.Warn("upstream WS SetReadDeadline failed for %s: %v", req.Provider, setErr)
}

msgType, data, readErr := upstream.ReadMessage()

// Clear the deadline immediately after a successful read so subsequent
// reads start their own fresh idle window.
if readErr == nil {
_ = upstream.SetReadDeadline(time.Time{})
}

if readErr != nil {
logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
h.pool.Discard(upstream)
session.SetUpstream(nil)

// Detect idle timeout: the upstream accepted the connection but sent no
// frame within idleTimeout. Return a 504 to the client rather than
// silently blocking.
var netErr net.Error
if errors.As(readErr, &netErr) && netErr.Timeout() {
timeoutMsg := fmt.Sprintf("upstream websocket idle timeout after %.0fs", idleTimeout.Seconds())
logger.Warn("upstream WS idle timeout for %s (no frame in %.0fs)", req.Provider, idleTimeout.Seconds())
writeWSError(session, 504, "upstream_timeout", timeoutMsg)
return true
}

logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
if !forwardedAny {
return false
}
Expand Down Expand Up @@ -354,6 +388,39 @@ func (h *WSResponsesHandler) tryNativeWSUpstream(
}
}

// wsUpstreamIdleTimeout is the fallback idle limit (60 s) for a single
// upstream.ReadMessage() call inside tryNativeWSUpstream.
//
// The limit is measured per read, i.e. from the moment a read starts waiting
// for the next frame, not from the start of the request. tryNativeWSUpstream
// arms a fresh deadline before every read and clears it after every
// successful one, so an upstream that emits a frame every 30 s is never cut
// off by the 60 s limit. When a read does exceed the limit, the upstream
// connection is discarded and the client receives a 504 error.
//
// upstreamWSIdleTimeout prefers the provider's
// NetworkConfig.StreamIdleTimeoutInSeconds and only falls back to this
// constant when no positive per-provider value is configured.
//
// NOTE(review): this is expected to equal DefaultStreamIdleTimeoutInSeconds,
// which is declared outside this file — confirm the two stay in sync.
const wsUpstreamIdleTimeout = time.Minute

// upstreamWSIdleTimeout resolves the idle timeout applied to upstream WS reads
// for provider.
//
// The result is wsUpstreamIdleTimeout unless all of the following hold, in
// which case the provider's NetworkConfig.StreamIdleTimeoutInSeconds is
// returned as a time.Duration:
//   - the handler has a config,
//   - GetProviderConfigRaw succeeds and yields a non-nil provider config,
//   - that config carries a NetworkConfig with a positive
//     StreamIdleTimeoutInSeconds.
func (h *WSResponsesHandler) upstreamWSIdleTimeout(provider schemas.ModelProvider) time.Duration {
	// Start from the fallback and override it only when a usable per-provider
	// value is found.
	timeout := wsUpstreamIdleTimeout
	if h.config != nil {
		providerCfg, lookupErr := h.config.GetProviderConfigRaw(provider)
		if lookupErr == nil && providerCfg != nil {
			// Zero or negative values are treated as "not configured".
			if network := providerCfg.NetworkConfig; network != nil && network.StreamIdleTimeoutInSeconds > 0 {
				timeout = time.Duration(network.StreamIdleTimeoutInSeconds) * time.Second
			}
		}
	}
	return timeout
}

// writeWSShortCircuitResponse writes a short-circuited plugin response as WS events.
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) {
if resp.ResponsesResponse != nil {
Expand Down
24 changes: 24 additions & 0 deletions transports/bifrost-http/handlers/wsresponses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers

import (
"testing"
"time"

"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
Expand Down Expand Up @@ -66,3 +67,26 @@ func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T)
t.Fatalf("parent request id should be unset, got %#v", got)
}
}

// ---------------------------------------------------------------------------
// upstreamWSIdleTimeout: config resolution
// ---------------------------------------------------------------------------

// TestUpstreamWSIdleTimeout_FallbackWhenConfigNil checks that a handler with
// no config (and therefore no per-provider override) resolves the idle
// timeout to the wsUpstreamIdleTimeout fallback.
func TestUpstreamWSIdleTimeout_FallbackWhenConfigNil(t *testing.T) {
	// The zero-value handler leaves config nil.
	handler := &WSResponsesHandler{}
	if timeout := handler.upstreamWSIdleTimeout(schemas.OpenAI); timeout != wsUpstreamIdleTimeout {
		t.Errorf("upstreamWSIdleTimeout = %v, want %v (wsUpstreamIdleTimeout)", timeout, wsUpstreamIdleTimeout)
	}
}

// TestUpstreamWSIdleTimeout_DefaultIs60s pins the fallback constant to 60 s so
// that the documentation comments quoting that value stay accurate.
func TestUpstreamWSIdleTimeout_DefaultIs60s(t *testing.T) {
	const expected = 60 * time.Second
	if got := wsUpstreamIdleTimeout; got != expected {
		t.Errorf("wsUpstreamIdleTimeout = %v, want %v", got, expected)
	}
}
151 changes: 151 additions & 0 deletions transports/bifrost-http/websocket/pool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package websocket

import (
"errors"
"net"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -128,6 +130,155 @@ func TestPoolClose(t *testing.T) {
assert.Error(t, err)
}

// ---------------------------------------------------------------------------
// Idle-timeout / SetReadDeadline behaviour tests.
// These tests exercise UpstreamConn.SetReadDeadline directly so that the
// tryNativeWSUpstream idle-timeout logic (which calls SetReadDeadline before
// each ReadMessage) is covered at the lowest possible level.
// ---------------------------------------------------------------------------

// TestUpstreamConnReadDeadline_Timeout verifies that a read that is given a
// very short deadline fails with a timeout error when the server never sends.
//
// The test server upgrades the connection and then stays silent. The handler
// waits on a stop channel that is closed during test teardown, so the handler
// goroutine exits and its deferred conn.Close() runs instead of being parked
// for the remainder of the test binary's lifetime.
func TestUpstreamConnReadDeadline_Timeout(t *testing.T) {
	// stop releases the stalled handler once the test is done.
	stop := make(chan struct{})

	upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()
		// Send nothing until teardown, simulating an upstream stall.
		<-stop
	}))
	defer server.Close()
	// Deferred after server.Close so it runs first (LIFO): the handler is
	// released before the server shuts down.
	defer close(stop)

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	wsConn, _, err := Dial(wsURL, nil)
	require.NoError(t, err)
	uc := newUpstreamConn(wsConn, schemas.OpenAI, "k1", wsURL)
	defer uc.Close()

	const shortDeadline = 100 * time.Millisecond
	start := time.Now()
	require.NoError(t, uc.SetReadDeadline(time.Now().Add(shortDeadline)))
	_, _, readErr := uc.ReadMessage()
	elapsed := time.Since(start)

	require.Error(t, readErr, "expected read to fail with timeout")

	// The handler under test classifies idle timeouts via net.Error.Timeout(),
	// so assert on exactly that.
	var netErr net.Error
	require.True(t, errors.As(readErr, &netErr) && netErr.Timeout(),
		"expected a net.Error with Timeout()=true, got: %v", readErr)

	// Elapsed time should be close to the deadline, not many seconds.
	assert.Less(t, elapsed, shortDeadline+500*time.Millisecond,
		"read should have timed out quickly")
}

// TestUpstreamConnReadDeadline_PeriodicFramesNoTimeout checks that an upstream
// emitting one frame per frameInterval never trips a read deadline that is
// longer than that interval. The client arms the deadline before every read
// and clears it after every successful one, following the same pattern as
// tryNativeWSUpstream.
func TestUpstreamConnReadDeadline_PeriodicFramesNoTimeout(t *testing.T) {
	const (
		frameInterval = 80 * time.Millisecond
		idleTimeout   = 300 * time.Millisecond
		numFrames     = 4
	)

	upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	handler := func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()
		// Emit numFrames frames, one per frameInterval, then hang up.
		for sent := 0; sent < numFrames; sent++ {
			time.Sleep(frameInterval)
			if werr := conn.WriteMessage(ws.TextMessage, []byte(`{"type":"ping"}`)); werr != nil {
				return
			}
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	wsConn, _, err := Dial(wsURL, nil)
	require.NoError(t, err)
	uc := newUpstreamConn(wsConn, schemas.OpenAI, "k1", wsURL)
	defer uc.Close()

	framesSeen := 0
	for {
		// Arm a fresh idle window for this read only.
		require.NoError(t, uc.SetReadDeadline(time.Now().Add(idleTimeout)))
		if _, _, rerr := uc.ReadMessage(); rerr != nil {
			// The only acceptable failure is the server hanging up after its
			// last frame; a timeout means the idle window was not honoured.
			var timeoutErr net.Error
			if errors.As(rerr, &timeoutErr) && timeoutErr.Timeout() {
				t.Fatalf("unexpected timeout after %d frames (interval %v < deadline %v)", framesSeen, frameInterval, idleTimeout)
			}
			break
		}
		// Frame received: drop the deadline before the next iteration re-arms it.
		_ = uc.SetReadDeadline(time.Time{})
		framesSeen++
	}
	assert.Equal(t, numFrames, framesSeen, "expected to receive all frames without timeout")
}

// TestUpstreamConnReadDeadline_OneThenSilent verifies that a timeout fires
// after idleness following the first frame, not at request-start + timeout.
// The server sends one frame immediately and then goes silent.
//
// The silent phase is implemented by waiting on a stop channel that is closed
// during test teardown, so the handler goroutine exits and its deferred
// conn.Close() runs instead of being parked for the remainder of the test
// binary's lifetime.
func TestUpstreamConnReadDeadline_OneThenSilent(t *testing.T) {
	const idleTimeout = 150 * time.Millisecond

	// stop releases the stalled handler once the test is done.
	stop := make(chan struct{})

	upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()
		// Send exactly one frame, then stall. If the write itself fails there
		// is nothing left to simulate, so bail out rather than stall on a
		// broken connection; the client-side assertions will report it.
		if werr := conn.WriteMessage(ws.TextMessage, []byte(`{"type":"response.created"}`)); werr != nil {
			return
		}
		// Stay silent until teardown, simulating an upstream stall after the
		// initial frame.
		<-stop
	}))
	defer server.Close()
	// Deferred after server.Close so it runs first (LIFO): the handler is
	// released before the server shuts down.
	defer close(stop)

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	wsConn, _, err := Dial(wsURL, nil)
	require.NoError(t, err)
	uc := newUpstreamConn(wsConn, schemas.OpenAI, "k1", wsURL)
	defer uc.Close()

	// First read: should succeed within idleTimeout.
	require.NoError(t, uc.SetReadDeadline(time.Now().Add(idleTimeout)))
	_, _, firstErr := uc.ReadMessage()
	require.NoError(t, firstErr, "first read (one frame sent) should succeed")
	// Clear deadline, mirrors tryNativeWSUpstream on a successful read.
	_ = uc.SetReadDeadline(time.Time{})

	// Second read: server is now silent. Set a new idle deadline.
	start := time.Now()
	require.NoError(t, uc.SetReadDeadline(time.Now().Add(idleTimeout)))
	_, _, secondErr := uc.ReadMessage()
	elapsed := time.Since(start)

	require.Error(t, secondErr, "second read should fail (upstream stalled)")
	var netErr net.Error
	require.True(t, errors.As(secondErr, &netErr) && netErr.Timeout(),
		"expected timeout error on second read, got: %v", secondErr)

	// The timeout should have fired approximately idleTimeout after the SECOND
	// read attempt, not at request-start + idleTimeout. Verify it did NOT fire
	// instantly (i.e. the first read succeeded and reset the clock).
	assert.GreaterOrEqual(t, elapsed, idleTimeout/2,
		"timeout should not fire before the idle deadline expires")
	assert.Less(t, elapsed, idleTimeout+500*time.Millisecond,
		"timeout should fire close to idleTimeout after stall begins")
}

func TestPoolExpiredConnection(t *testing.T) {
server := startTestWSServer(t)
defer server.Close()
Expand Down