Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions transports/bifrost-http/websocket/pool.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package websocket

import (
"errors"
"fmt"
"net"
"sync"
"time"

Expand Down Expand Up @@ -69,6 +71,19 @@ func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error
continue
}

// Liveness probe: attempt a non-blocking read with a 1ms deadline.
// An idle connection should produce no frames while sitting in the pool.
// If the upstream already sent a close frame, ReadMessage returns
// immediately with a close/EOF error, revealing the stale state before
// the caller uses the connection. A timeout error means the connection
// is alive and has nothing queued.
if !p.isLive(conn) {
conn.Close()
p.mu.Lock()
conns = p.idle[key]
continue
}

p.mu.Lock()
p.inFlight++
p.mu.Unlock()
Expand Down Expand Up @@ -199,6 +214,34 @@ func (p *Pool) evictLoop() {
}
}

// isLive performs a network-level liveness probe on conn.
// It sets a 1ms read deadline and attempts a ReadMessage. A net.Error with
// Timeout() == true means the connection is alive with nothing queued, which
// is the expected state for an idle pooled connection. Any other outcome
// (close frame, EOF, protocol error) means the connection is stale.
// The read deadline is always reset to zero before the function returns so
// that a live connection can be used without deadline constraints.
// isLive performs a network-level liveness probe on conn.
// It sets a 1ms read deadline and attempts a ReadMessage. A net.Error with
// Timeout() == true means the connection is alive with nothing queued, which
// is the expected state for an idle pooled connection. Any other outcome
// (close frame, EOF, protocol error) means the connection is stale.
// The read deadline is always reset to zero before the function returns so
// that a live connection can be used without deadline constraints.
//
// NOTE(review): if UpstreamConn.ReadMessage delegates to gorilla/websocket's
// Conn.ReadMessage (the test suite uses gorilla's Upgrader, so this seems
// likely — TODO confirm), read errors are documented as permanent: once
// ReadMessage returns a non-nil error, including a deadline timeout, every
// subsequent read on that Conn returns the same error. Under that behavior
// a connection this probe declares "live" may already be unusable for reads.
// Verify that UpstreamConn either resets/reconstructs the read state after a
// timeout or does not wrap gorilla; otherwise the probe needs to operate on
// the underlying net.Conn rather than the websocket framing layer.
func (p *Pool) isLive(conn *UpstreamConn) bool {
// Failing to arm the deadline means we cannot bound the probe; treat as stale.
if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
return false
}
_, _, err := conn.ReadMessage()
// Always clear the deadline regardless of outcome.
_ = conn.SetReadDeadline(time.Time{})
if err == nil {
// Received a frame while idle — unexpected; treat as stale because
// we cannot safely discard the frame content here.
return false
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
// Deadline fired with no data: connection is alive.
return true
}
// Close frame, EOF, or other error: connection is stale.
return false
}

func (p *Pool) evictExpired() {
p.mu.Lock()
defer p.mu.Unlock()
Expand Down
77 changes: 77 additions & 0 deletions transports/bifrost-http/websocket/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,83 @@ func TestPoolClose(t *testing.T) {
assert.Error(t, err)
}

// TestPoolGetEvictsStaleSessionConn verifies that Pool.Get detects a
// server-side close via the liveness probe and dials a fresh connection
// instead of handing out the stale one (issue #3002).
// TestPoolGetEvictsStaleSessionConn verifies that Pool.Get detects a
// server-side close via the liveness probe and dials a fresh connection
// instead of handing out the stale one (issue #3002).
func TestPoolGetEvictsStaleSessionConn(t *testing.T) {
	// closeCh is closed by the test to signal the server to close connections.
	closeCh := make(chan struct{})
	// dialCh receives a value each time the mock server accepts a new upgrade.
	// Buffered so the handler never blocks if the test is slow to drain it.
	dialCh := make(chan struct{}, 8)

	upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()
		dialCh <- struct{}{}
		// Block until the test signals shutdown. A plain receive, not a
		// single-case select: the channel is closed exactly once, so this
		// also unblocks immediately for connections accepted afterwards.
		<-closeCh
		// Send a normal close frame so the TCP socket carries the close
		// before the deferred conn.Close() tears the connection down.
		_ = conn.WriteMessage(ws.CloseMessage,
			ws.FormatCloseMessage(ws.CloseNormalClosure, "done"))
		// Give the frame a moment to flush before the deferred Close fires.
		time.Sleep(10 * time.Millisecond)
	}))
	defer server.Close()

	config := &schemas.WSPoolConfig{
		MaxIdlePerKey:                5,
		MaxTotalConnections:          10,
		IdleTimeoutSeconds:           300,
		MaxConnectionLifetimeSeconds: 3600,
	}
	pool := NewPool(config)
	defer pool.Close()

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}

	// Dial the first connection and confirm it reaches the mock server.
	conn1, err := pool.Get(key, nil)
	require.NoError(t, err)
	require.NotNil(t, conn1)
	<-dialCh // wait until server has accepted connection #1

	// Return it to the idle pool.
	pool.Return(conn1)

	// Signal the server to close the upstream side of connection #1.
	close(closeCh)

	// Give the OS a moment to deliver the close frame into the socket buffer.
	time.Sleep(50 * time.Millisecond)

	// Pool.Get must detect the stale connection via the liveness probe and dial
	// a fresh one. It should not return conn1.
	conn2, err := pool.Get(key, nil)
	require.NoError(t, err)
	require.NotNil(t, conn2)

	// A fresh dial was triggered — wait for the mock server to record it.
	select {
	case <-dialCh:
		// Good: server accepted a new upstream connection.
	case <-time.After(2 * time.Second):
		t.Fatal("expected a fresh upstream dial after stale-connection eviction, but none arrived")
	}

	// The returned connection must not be the stale one.
	assert.NotSame(t, conn1, conn2, "Pool.Get must not return the stale session-pinned connection")
	assert.True(t, conn1.IsClosed(), "stale connection must have been closed by the pool")

	pool.Discard(conn2)
}

func TestPoolExpiredConnection(t *testing.T) {
server := startTestWSServer(t)
defer server.Close()
Expand Down