Skip to content

Commit 5ec8def

Browse files
authored
Merge pull request #38 from EternisAI/joeldrotleff/fix-client-dropped-connection
fix deep research disconnect issue
2 parents b68cfe0 + f7bcec1 commit 5ec8def

File tree

4 files changed

+55
-11
lines changed

4 files changed

+55
-11
lines changed

cmd/server/main.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ func main() {
129129
mcpService := mcp.NewService()
130130
searchService := search.NewService(logger.WithComponent("search"))
131131

132-
// Initialize deep research storage
132+
// Initialize deep research storage and session manager
133133
deeprStorage := deepr.NewDBStorage(logger.WithComponent("deepr-storage"), db.DB)
134+
deeprSessionManager := deepr.NewSessionManager(logger.WithComponent("deepr-session"))
134135

135136
// Initialize handlers
136137
oauthHandler := oauth.NewHandler(oauthService, logger.WithComponent("oauth"))
@@ -194,6 +195,7 @@ func main() {
194195
mcpHandler: mcpHandler,
195196
searchHandler: searchHandler,
196197
deeprStorage: deeprStorage,
198+
deeprSessionManager: deeprSessionManager,
197199
})
198200

199201
// Initialize GraphQL server for Telegram
@@ -295,6 +297,7 @@ type restServerInput struct {
295297
mcpHandler *mcp.Handler
296298
searchHandler *search.Handler
297299
deeprStorage deepr.MessageStorage
300+
deeprSessionManager *deepr.SessionManager
298301
}
299302

300303
func setupRESTServer(input restServerInput) *gin.Engine {
@@ -369,7 +372,7 @@ func setupRESTServer(input restServerInput) *gin.Engine {
369372
api.POST("/exa/search", input.searchHandler.PostExaSearchHandler) // POST /api/v1/exa/search (Exa AI)
370373

371374
// Deep Research WebSocket endpoint (protected)
372-
api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient, input.deeprStorage)) // WebSocket proxy for deep research
375+
api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient, input.deeprStorage, input.deeprSessionManager)) // WebSocket proxy for deep research
373376
}
374377

375378
// Protected proxy routes

internal/deepr/handlers.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var upgrader = websocket.Upgrader{
1818
}
1919

2020
// DeepResearchHandler handles WebSocket connections for deep research streaming
21-
func DeepResearchHandler(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient, storage MessageStorage) gin.HandlerFunc {
21+
func DeepResearchHandler(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient, storage MessageStorage, sessionManager *SessionManager) gin.HandlerFunc {
2222
return func(c *gin.Context) {
2323
log := logger.WithContext(c.Request.Context()).WithComponent("deepr")
2424

@@ -74,8 +74,8 @@ func DeepResearchHandler(logger *logger.Logger, trackingService *request_trackin
7474
slog.String("chat_id", chatID),
7575
slog.String("remote_addr", c.Request.RemoteAddr))
7676

77-
// Create service instance with database storage
78-
service := NewService(logger, trackingService, firebaseClient, storage)
77+
// Create service instance with shared session manager
78+
service := NewService(logger, trackingService, firebaseClient, storage, sessionManager)
7979

8080
// Handle the WebSocket connection
8181
service.HandleConnection(c.Request.Context(), conn, userID, chatID)

internal/deepr/service.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,13 @@ func (s *Service) validateFreemiumAccess(ctx context.Context, clientConn *websoc
205205
}
206206

207207
// NewService creates a new deep research service with database storage
208-
func NewService(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient, storage MessageStorage) *Service {
208+
func NewService(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient, storage MessageStorage, sessionManager *SessionManager) *Service {
209209
return &Service{
210210
logger: logger,
211211
trackingService: trackingService,
212212
firebaseClient: firebaseClient,
213213
storage: storage,
214-
sessionManager: NewSessionManager(logger),
214+
sessionManager: sessionManager,
215215
}
216216
}
217217

@@ -229,6 +229,12 @@ func (s *Service) HandleConnection(ctx context.Context, clientConn *websocket.Co
229229
// Check if this is a reconnection
230230
isReconnection := s.sessionManager.HasActiveBackend(userID, chatID)
231231

232+
log.Info("reconnection check performed",
233+
slog.String("user_id", userID),
234+
slog.String("chat_id", chatID),
235+
slog.String("client_id", clientID),
236+
slog.Bool("has_active_backend", isReconnection))
237+
232238
if isReconnection {
233239
log.Info("reconnection to existing session detected",
234240
slog.String("user_id", userID),
@@ -666,9 +672,10 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket
666672
}()
667673
}
668674

669-
// Create session context derived from incoming context but independent of any single client
670-
// This ensures cleanup on server shutdown while allowing session to outlive individual clients
671-
sessionCtx, cancel := context.WithCancel(ctx)
675+
// Create session context independent of any single client's request context
676+
// This allows the backend connection to outlive individual client disconnections
677+
// while still allowing cleanup when the session completes
678+
sessionCtx, cancel := context.WithCancel(context.Background())
672679
defer cancel()
673680

674681
// Create and register session

internal/deepr/session_manager.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ func (sm *SessionManager) CreateSession(userID, chatID string, backendConn *webs
5858

5959
key := sm.getSessionKey(userID, chatID)
6060

61+
// Check if session already exists
62+
if existingSession, exists := sm.sessions[key]; exists {
63+
sm.logger.WithComponent("deepr-session").Warn("OVERWRITING existing session",
64+
slog.String("user_id", userID),
65+
slog.String("chat_id", chatID),
66+
slog.String("session_key", key),
67+
slog.Int("existing_client_count", len(existingSession.clientConns)))
68+
// Cancel the existing session's context
69+
if existingSession.CancelFunc != nil {
70+
existingSession.CancelFunc()
71+
}
72+
}
73+
6174
session := &ActiveSession{
6275
UserID: userID,
6376
ChatID: chatID,
@@ -85,6 +98,11 @@ func (sm *SessionManager) RemoveSession(userID, chatID string) {
8598

8699
key := sm.getSessionKey(userID, chatID)
87100

101+
sm.logger.WithComponent("deepr-session").Info("RemoveSession called",
102+
slog.String("user_id", userID),
103+
slog.String("chat_id", chatID),
104+
slog.String("session_key", key))
105+
88106
if session, exists := sm.sessions[key]; exists {
89107
// Close all client connections
90108
session.mu.Lock()
@@ -238,7 +256,23 @@ func (sm *SessionManager) HasActiveBackend(userID, chatID string) bool {
238256
defer sm.mu.RUnlock()
239257

240258
key := sm.getSessionKey(userID, chatID)
241-
_, exists := sm.sessions[key]
259+
session, exists := sm.sessions[key]
260+
261+
var clientCount int
262+
if exists {
263+
session.mu.RLock()
264+
clientCount = len(session.clientConns)
265+
session.mu.RUnlock()
266+
}
267+
268+
sm.logger.WithComponent("deepr-session").Debug("HasActiveBackend called",
269+
slog.String("user_id", userID),
270+
slog.String("chat_id", chatID),
271+
slog.String("session_key", key),
272+
slog.Bool("session_exists", exists),
273+
slog.Int("client_count", clientCount),
274+
slog.Int("total_sessions", len(sm.sessions)))
275+
242276
return exists
243277
}
244278

0 commit comments

Comments
 (0)