diff --git a/transports/bifrost-http/websocket/pool.go b/transports/bifrost-http/websocket/pool.go
index c850b58e6d..47c7896f09 100644
--- a/transports/bifrost-http/websocket/pool.go
+++ b/transports/bifrost-http/websocket/pool.go
@@ -1,7 +1,9 @@
 package websocket
 
 import (
+	"errors"
 	"fmt"
+	"net"
 	"sync"
 	"time"
 
@@ -69,6 +71,19 @@ func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
 			continue
 		}
 
+		// Liveness probe: attempt a read capped by a 1ms deadline. An idle
+		// connection should produce no frames while sitting in the pool.
+		// If the upstream already sent a close frame, ReadMessage returns
+		// immediately with a close/EOF error, revealing the stale state
+		// before the caller uses the connection. A timeout error means the
+		// connection is alive and has nothing queued.
+		if !p.isLive(conn) {
+			conn.Close()
+			p.mu.Lock()
+			conns = p.idle[key]
+			continue
+		}
+
 		p.mu.Lock()
 		p.inFlight++
 		p.mu.Unlock()
@@ -199,6 +214,34 @@ func (p *Pool) evictLoop() {
 	}
 }
 
+// isLive performs a network-level liveness probe on conn.
+// It sets a 1ms read deadline and attempts a ReadMessage. A net.Error with
+// Timeout() == true means the connection is alive with nothing queued, which
+// is the expected state for an idle pooled connection. Any other outcome
+// (close frame, EOF, protocol error) means the connection is stale.
+// The read deadline is always reset to zero before the function returns so
+// that a live connection can be used without deadline constraints.
+func (p *Pool) isLive(conn *UpstreamConn) bool {
+	if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
+		return false
+	}
+	_, _, err := conn.ReadMessage()
+	// Always clear the deadline regardless of outcome.
+	_ = conn.SetReadDeadline(time.Time{})
+	if err == nil {
+		// Received a frame while idle: unexpected. Treat the connection as
+		// stale, because we cannot safely discard the frame content here.
+		return false
+	}
+	var netErr net.Error
+	if errors.As(err, &netErr) && netErr.Timeout() {
+		// Deadline fired with no data: connection is alive.
+		return true
+	}
+	// Close frame, EOF, or other error: connection is stale.
+	return false
+}
+
 func (p *Pool) evictExpired() {
 	p.mu.Lock()
 	defer p.mu.Unlock()
diff --git a/transports/bifrost-http/websocket/pool_test.go b/transports/bifrost-http/websocket/pool_test.go
index 9f734656ab..aec7d73def 100644
--- a/transports/bifrost-http/websocket/pool_test.go
+++ b/transports/bifrost-http/websocket/pool_test.go
@@ -128,6 +128,81 @@ func TestPoolClose(t *testing.T) {
 	assert.Error(t, err)
 }
 
+// TestPoolGetEvictsStaleSessionConn verifies that Pool.Get detects a
+// server-side close via the liveness probe and dials a fresh connection
+// instead of handing out the stale one (issue #3002).
+func TestPoolGetEvictsStaleSessionConn(t *testing.T) {
+	// closeCh is closed by the test to signal the server to close connection #1.
+	closeCh := make(chan struct{})
+	// dialCh receives a value each time the mock server accepts a new upgrade.
+	dialCh := make(chan struct{}, 8)
+
+	upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			return
+		}
+		defer conn.Close()
+		dialCh <- struct{}{}
+		// Block until the test signals the server to close this connection.
+		<-closeCh
+		// Send a normal close frame so the TCP socket carries the close before
+		// the server's defer conn.Close() runs.
+		_ = conn.WriteMessage(ws.CloseMessage,
+			ws.FormatCloseMessage(ws.CloseNormalClosure, "done"))
+		time.Sleep(10 * time.Millisecond)
+	}))
+	defer server.Close()
+
+	config := &schemas.WSPoolConfig{
+		MaxIdlePerKey:                5,
+		MaxTotalConnections:          10,
+		IdleTimeoutSeconds:           300,
+		MaxConnectionLifetimeSeconds: 3600,
+	}
+	pool := NewPool(config)
+	defer pool.Close()
+
+	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
+	key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
+
+	// Dial the first connection and confirm it reaches the mock server.
+	conn1, err := pool.Get(key, nil)
+	require.NoError(t, err)
+	require.NotNil(t, conn1)
+	<-dialCh // wait until the server has accepted connection #1
+
+	// Return it to the idle pool.
+	pool.Return(conn1)
+
+	// Signal the server to close the upstream side of connection #1.
+	close(closeCh)
+
+	// Give the OS a moment to deliver the close frame into the socket buffer.
+	time.Sleep(50 * time.Millisecond)
+
+	// Pool.Get must detect the stale connection via the liveness probe and
+	// dial a fresh one instead of returning conn1.
+	conn2, err := pool.Get(key, nil)
+	require.NoError(t, err)
+	require.NotNil(t, conn2)
+
+	// A fresh dial was triggered; wait for the mock server to record it.
+	select {
+	case <-dialCh:
+		// Good: the server accepted a new upstream connection.
+	case <-time.After(2 * time.Second):
+		t.Fatal("expected a fresh upstream dial after stale-connection eviction, but none arrived")
+	}
+
+	// The returned connection must not be the stale one.
+	assert.NotSame(t, conn1, conn2, "Pool.Get must not return the stale session-pinned connection")
+	assert.True(t, conn1.IsClosed(), "stale connection must have been closed by the pool")
+
+	pool.Discard(conn2)
+}
+
 func TestPoolExpiredConnection(t *testing.T) {
 	server := startTestWSServer(t)
 	defer server.Close()
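For context on the probe pattern itself, here is a minimal standalone sketch of the same deadline-bounded read applied to a bare gorilla/websocket connection, outside the pool. `probeIdleConn` and the placeholder endpoint are illustrative names for this note, not part of the change above; the pool wiring (`UpstreamConn`, the idle map, locking) is omitted.

```go
package main

import (
	"errors"
	"log"
	"net"
	"time"

	"github.com/gorilla/websocket"
)

// probeIdleConn reports whether an idle connection is still usable.
// A read that times out after 1ms means nothing was queued: the peer
// has not sent a close frame, so the connection is presumed alive.
func probeIdleConn(conn *websocket.Conn) bool {
	if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
		return false
	}
	_, _, err := conn.ReadMessage()
	_ = conn.SetReadDeadline(time.Time{}) // clear the probe deadline

	if err == nil {
		// A frame arrived while the connection sat idle; the probe would
		// swallow it, so conservatively treat the connection as unusable.
		return false
	}
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}

func main() {
	// Placeholder endpoint; substitute a real upstream to run this.
	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:8080/ws", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	log.Printf("idle connection live: %v", probeIdleConn(conn))
}
```

The 1ms deadline trades a tiny per-Get latency cost for catching upstream closes before they surface as request failures. Note that a connection which has already received any frame is also discarded: consuming the frame inside the probe would lose data, so treating it as stale is the safe choice.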