From 803aa2160186ed2868211fa26245393dc1f178ce Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Wed, 8 Oct 2025 00:35:47 -0700 Subject: [PATCH 1/8] feat: deep research check user status --- cmd/server/main.go | 25 +++- docs/deep-research.md | 197 +++++++++++++++++++++++++++++++ internal/auth/firebase_client.go | 157 ++++++++++++++++++++++++ internal/config/config.go | 4 +- internal/deepr/handlers.go | 4 +- internal/deepr/service.go | 44 +++++-- 6 files changed, 418 insertions(+), 13 deletions(-) create mode 100644 docs/deep-research.md create mode 100644 internal/auth/firebase_client.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 6872647..537945f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -99,6 +99,27 @@ func main() { os.Exit(1) } + // Initialize Firebase client for Firestore (used for deep research tracking) + var firebaseClient *auth.FirebaseClient + + if config.AppConfig.FirebaseCredJSON != "" { + firebaseClient, err = auth.NewFirebaseClient(context.Background(), config.AppConfig.FirebaseProjectID, config.AppConfig.FirebaseCredJSON) + if err != nil { + log.Error("failed to initialize firebase client", slog.String("error", err.Error())) + os.Exit(1) + } + log.Info("firebase client initialized") + + // Ensure cleanup on shutdown + defer func() { + if err := firebaseClient.Close(); err != nil { + log.Error("failed to close firebase client", slog.String("error", err.Error())) + } + }() + } else { + log.Warn("firebase credentials not provided - deep research tracking will not work properly") + } + // Initialize services oauthService := oauth.NewService(logger.WithComponent("oauth")) composioService := composio.NewService(logger.WithComponent("composio")) @@ -161,6 +182,7 @@ func main() { router := setupRESTServer(restServerInput{ logger: logger, firebaseAuth: firebaseAuth, + firebaseClient: firebaseClient, requestTrackingService: requestTrackingService, oauthHandler: oauthHandler, composioHandler: composioHandler, @@ -260,6 +282,7 
@@ func getKeys(m map[string]string) []string { type restServerInput struct { logger *logger.Logger firebaseAuth *auth.FirebaseAuthMiddleware + firebaseClient *auth.FirebaseClient requestTrackingService *request_tracking.Service oauthHandler *oauth.Handler composioHandler *composio.Handler @@ -341,7 +364,7 @@ func setupRESTServer(input restServerInput) *gin.Engine { api.POST("/exa/search", input.searchHandler.PostExaSearchHandler) // POST /api/v1/exa/search (Exa AI) // Deep Research WebSocket endpoint (protected) - api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService)) // WebSocket proxy for deep research + api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient)) // WebSocket proxy for deep research } // Protected proxy routes diff --git a/docs/deep-research.md b/docs/deep-research.md new file mode 100644 index 0000000..3a60bf3 --- /dev/null +++ b/docs/deep-research.md @@ -0,0 +1,197 @@ +# Deep Research Functionality + +The Deep Research (DeepR) functionality provides a WebSocket-based proxy service that enables users to perform advanced research queries through a dedicated backend service. This feature implements a freemium model with usage tracking and subscription-based access control. + +## Architecture Overview + +The Deep Research functionality consists of four main components: + +1. **WebSocket Handler** (`internal/deepr/handlers.go`) - Manages incoming WebSocket connections +2. **Service Layer** (`internal/deepr/service.go`) - Handles business logic and backend communication +3. **Data Models** (`internal/deepr/models.go`) - Defines message structures +4. **Firebase Integration** (`internal/auth/firebase_client.go`) - Manages usage tracking + +## API Endpoint + +``` +GET /api/deepresearch/ws?chat_id={chat_id} +``` + +**Authentication**: Required (JWT token) +**Protocol**: WebSocket upgrade from HTTP + +## Request Flow + +### 1. Connection Establishment + +1. 
Client initiates WebSocket connection to `/api/deepresearch/ws` +2. Server validates JWT authentication +3. Server extracts `chat_id` from query parameters +4. HTTP connection is upgraded to WebSocket +5. Service instance is created with dependencies + +### 2. Subscription Validation + +The system implements a two-tier access model: + +#### Pro Users +- Users with active Pro subscriptions have unlimited access +- Usage is tracked for analytics purposes +- No restrictions on usage frequency + +#### Freemium Users +- Limited to **one free deep research session** per user +- Usage is tracked in Firebase Firestore +- Subsequent attempts are blocked with upgrade prompt + +### 3. Backend Connection + +1. Server constructs WebSocket URL to deep research backend: + ``` + ws://{DEEP_RESEARCH_WS}/deep_research/{user_id}/{chat_id}/ + ``` +2. Establishes connection to external deep research service +3. Creates bidirectional message forwarding + +### 4. Message Relay + +The service acts as a transparent proxy: + +- **Client → Backend**: Forwards all client messages to deep research backend +- **Backend → Client**: Forwards all backend responses to client +- **Error Handling**: Graceful handling of connection drops and errors + +## Data Models + +### Message Structure +```go +type Message struct { + Type string `json:"type"` + Content string `json:"content"` + Data string `json:"data,omitempty"` +} +``` + +### Request Structure +```go +type Request struct { + Query string `json:"query"` + Type string `json:"type"` +} +``` + +### Response Structure +```go +type Response struct { + Type string `json:"type"` + Content string `json:"content"` + Status string `json:"status,omitempty"` +} +``` + +## Usage Tracking + +### Firebase Firestore Schema + +Collection: `deep_research_usage` +Document ID: `{user_id}` + +```go +type DeepResearchUsage struct { + UserID string `firestore:"user_id"` + HasUsedFreeDeepResearch bool `firestore:"has_used_free_deep_research"` + FirstUsedAt time.Time 
`firestore:"first_used_at"` + LastUsedAt time.Time `firestore:"last_used_at"` + UsageCount int64 `firestore:"usage_count"` +} +``` + +### Tracking Logic + +#### Freemium Users +- `HasUsedFreeDeepResearch`: Set to `true` after first use +- Prevents subsequent free usage +- Error code: `FREE_LIMIT_REACHED` + +#### Pro Users +- `HasUsedFreeDeepResearch`: Set to `false` when a Pro session creates the record; an existing value is left unchanged +- `UsageCount`: Incremented for analytics +- No usage restrictions + +## Error Handling + +### Authentication Errors +- **401 Unauthorized**: Missing or invalid JWT token +- **400 Bad Request**: Missing `chat_id` parameter + +### Subscription Errors +- **FREE_LIMIT_REACHED**: Freemium user exceeded free usage +- **Subscription verification failure**: Database errors + +### Connection Errors +- **Backend unavailable**: `DEEP_RESEARCH_WS` not configured +- **Connection failure**: Deep research service unreachable +- **WebSocket errors**: Connection drops and protocol errors + +## Configuration + +### Environment Variables + +- `DEEP_RESEARCH_WS`: Hostname of the deep research backend service +- Firebase configuration for usage tracking + +### Dependencies + +- **Logger**: Structured logging with context +- **Request Tracking Service**: Subscription validation +- **Firebase Client**: Usage tracking and analytics +- **WebSocket Library**: Gorilla WebSocket for connection management + +## Security Considerations + +1. **Authentication**: All connections require valid JWT tokens +2. **Origin Validation**: WebSocket upgrader allows all origins (configurable) +3. **User Isolation**: Each connection is scoped to authenticated user +4. 
**Chat Isolation**: Messages are scoped to specific chat sessions + +## Monitoring and Logging + +The service provides comprehensive logging: + +- Connection establishment and termination +- Authentication status +- Subscription validation results +- Message forwarding (client ↔ backend) +- Error conditions and stack traces +- Usage tracking events + +All logs include structured fields for user ID, chat ID, and request context. + +## Usage Examples + +### Successful Connection (Pro User) +``` +1. Client connects with valid JWT +2. Server validates Pro subscription +3. Backend connection established +4. Bidirectional message relay begins +``` + +### Freemium User (First Use) +``` +1. Client connects with valid JWT +2. Server checks usage history (none found) +3. Usage record created in Firebase +4. Backend connection established +5. Session proceeds normally +``` + +### Freemium User (Subsequent Use) +``` +1. Client connects with valid JWT +2. Server finds existing usage record +3. Connection rejected with FREE_LIMIT_REACHED error +4. Upgrade prompt sent to client +``` + +This architecture provides a scalable, secure, and monetizable deep research service with proper usage controls and comprehensive monitoring. 
diff --git a/internal/auth/firebase_client.go b/internal/auth/firebase_client.go new file mode 100644 index 0000000..cc48e7b --- /dev/null +++ b/internal/auth/firebase_client.go @@ -0,0 +1,157 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "cloud.google.com/go/firestore" + firebase "firebase.google.com/go/v4" + "google.golang.org/api/option" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// FirebaseClient wraps Firebase services +type FirebaseClient struct { + firestoreClient *firestore.Client +} + +// NewFirebaseClient creates a new Firebase client with Firestore access +func NewFirebaseClient(ctx context.Context, projectID, credJSON string) (*FirebaseClient, error) { + opt := option.WithCredentialsJSON([]byte(credJSON)) + + // Create Firebase config with project ID + config := &firebase.Config{ + ProjectID: projectID, + } + + app, err := firebase.NewApp(ctx, config, opt) + if err != nil { + return nil, fmt.Errorf("error initializing firebase app: %v", err) + } + + firestoreClient, err := app.Firestore(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get Firestore client: %w", err) + } + + return &FirebaseClient{ + firestoreClient: firestoreClient, + }, nil +} + +// Close closes the Firestore client +func (f *FirebaseClient) Close() error { + if f.firestoreClient != nil { + return f.firestoreClient.Close() + } + return nil +} + +// DeepResearchUsage represents a user's deep research usage record +type DeepResearchUsage struct { + UserID string `firestore:"user_id"` + HasUsedFreeDeepResearch bool `firestore:"has_used_free_deep_research"` + FirstUsedAt time.Time `firestore:"first_used_at"` + LastUsedAt time.Time `firestore:"last_used_at"` + UsageCount int64 `firestore:"usage_count"` +} + +// HasUsedFreeDeepResearch checks if a freemium user has already used deep research +func (f *FirebaseClient) HasUsedFreeDeepResearch(ctx context.Context, userID string) (bool, error) { + docRef := 
f.firestoreClient.Collection("deep_research_usage").Doc(userID) + doc, err := docRef.Get(ctx) + + if err != nil { + // If document doesn't exist, user hasn't used it yet + if status.Code(err) == codes.NotFound { + return false, nil + } + return false, fmt.Errorf("failed to get deep research usage: %w", err) + } + + var usage DeepResearchUsage + if err := doc.DataTo(&usage); err != nil { + return false, fmt.Errorf("failed to parse deep research usage: %w", err) + } + + return usage.HasUsedFreeDeepResearch, nil +} + +// MarkFreeDeepResearchUsed marks that a freemium user has used their free deep research +func (f *FirebaseClient) MarkFreeDeepResearchUsed(ctx context.Context, userID string) error { + docRef := f.firestoreClient.Collection("deep_research_usage").Doc(userID) + + // Check if document exists + doc, err := docRef.Get(ctx) + now := time.Now() + + if err != nil { + // Document doesn't exist, create new one + usage := DeepResearchUsage{ + UserID: userID, + HasUsedFreeDeepResearch: true, + FirstUsedAt: now, + LastUsedAt: now, + UsageCount: 1, + } + _, err := docRef.Set(ctx, usage) + if err != nil { + return fmt.Errorf("failed to create deep research usage record: %w", err) + } + return nil + } + + // Document exists, update it + var usage DeepResearchUsage + if err := doc.DataTo(&usage); err != nil { + return fmt.Errorf("failed to parse existing usage record: %w", err) + } + + // Update the record + _, err = docRef.Set(ctx, map[string]interface{}{ + "has_used_free_deep_research": true, + "last_used_at": now, + "usage_count": usage.UsageCount + 1, + }, firestore.MergeAll) + + if err != nil { + return fmt.Errorf("failed to update deep research usage record: %w", err) + } + + return nil +} + +// IncrementDeepResearchUsage increments usage counter for pro users (for analytics) +func (f *FirebaseClient) IncrementDeepResearchUsage(ctx context.Context, userID string) error { + docRef := f.firestoreClient.Collection("deep_research_usage").Doc(userID) + now := 
time.Now() + + doc, err := docRef.Get(ctx) + if err != nil { + // Create new record for pro user + usage := DeepResearchUsage{ + UserID: userID, + HasUsedFreeDeepResearch: false, // Pro users don't count as "free" usage + FirstUsedAt: now, + LastUsedAt: now, + UsageCount: 1, + } + _, err := docRef.Set(ctx, usage) + return err + } + + var usage DeepResearchUsage + if err := doc.DataTo(&usage); err != nil { + return fmt.Errorf("failed to parse usage record: %w", err) + } + + // Update usage count and last used time + _, err = docRef.Set(ctx, map[string]interface{}{ + "last_used_at": now, + "usage_count": usage.UsageCount + 1, + }, firestore.MergeAll) + + return err +} diff --git a/internal/config/config.go b/internal/config/config.go index 171f39e..cb5857c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -89,7 +89,7 @@ var AppConfig *Config func LoadConfig() { // Load .env file if it exists - if err := godotenv.Load(); err != nil { + if err := godotenv.Load(".env"); err != nil { log.Println("No .env file found, using environment variables") } @@ -98,7 +98,7 @@ func LoadConfig() { GinMode: getEnvOrDefault("GIN_MODE", "release"), // Firebase - FirebaseProjectID: getEnvOrDefault("FIREBASE_PROJECT_ID", "enchanted-login-8fdb9"), + FirebaseProjectID: getEnvOrDefault("FIREBASE_PROJECT_ID", "silo-dev-95230"), // Database DatabaseURL: getEnvOrDefault("DATABASE_URL", "postgres://localhost/tee_api?sslmode=disable"), diff --git a/internal/deepr/handlers.go b/internal/deepr/handlers.go index fc6d9c8..7195782 100644 --- a/internal/deepr/handlers.go +++ b/internal/deepr/handlers.go @@ -18,7 +18,7 @@ var upgrader = websocket.Upgrader{ } // DeepResearchHandler handles WebSocket connections for deep research streaming -func DeepResearchHandler(logger *logger.Logger, trackingService *request_tracking.Service) gin.HandlerFunc { +func DeepResearchHandler(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient) 
gin.HandlerFunc { return func(c *gin.Context) { log := logger.WithContext(c.Request.Context()).WithComponent("deepr") @@ -60,7 +60,7 @@ func DeepResearchHandler(logger *logger.Logger, trackingService *request_trackin slog.String("chat_id", chatID)) // Create service instance - service := NewService(logger, trackingService) + service := NewService(logger, trackingService, firebaseClient) // Handle the WebSocket connection service.HandleConnection(c.Request.Context(), conn, userID, chatID) diff --git a/internal/deepr/service.go b/internal/deepr/service.go index a55a54a..75726ee 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -6,6 +6,7 @@ import ( "net/url" "os" + "github.com/eternisai/enchanted-proxy/internal/auth" "github.com/eternisai/enchanted-proxy/internal/logger" "github.com/eternisai/enchanted-proxy/internal/request_tracking" "github.com/gorilla/websocket" @@ -15,13 +16,15 @@ import ( type Service struct { logger *logger.Logger trackingService *request_tracking.Service + firebaseClient *auth.FirebaseClient } // NewService creates a new deep research service -func NewService(logger *logger.Logger, trackingService *request_tracking.Service) *Service { +func NewService(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient) *Service { return &Service{ logger: logger, trackingService: trackingService, + firebaseClient: firebaseClient, } } @@ -42,16 +45,41 @@ func (s *Service) HandleConnection(ctx context.Context, clientConn *websocket.Co log.Info("user has active pro subscription", slog.String("user_id", userID), slog.Time("expires_at", *proExpiresAt)) + + // Track usage for pro users (for analytics) + if err := s.firebaseClient.IncrementDeepResearchUsage(ctx, userID); err != nil { + log.Error("failed to track pro user deep research usage", slog.String("error", err.Error())) + // Don't block pro users on tracking error + } } else { + // Freemium user - check if they've already used their free 
deep research log.Info("user is on freemium plan", slog.String("user_id", userID)) - } - // TODO: Implement subscription-based access control - // For now, allow both pro and freemium users - // In the future, you might want to: - // - Limit freemium users to certain features - // - Rate limit freemium users - // - Block freemium users from certain endpoints + hasUsed, err := s.firebaseClient.HasUsedFreeDeepResearch(ctx, userID) + if err != nil { + log.Error("failed to check freemium deep research usage", slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify usage status"}`)) + clientConn.Close() + return + } + + if hasUsed { + log.Info("freemium user has already used their free deep research", slog.String("user_id", userID)) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "You have already used your free deep research. Please upgrade to Pro for unlimited access.", "error_code": "FREE_LIMIT_REACHED"}`)) + clientConn.Close() + return + } + + // Mark that the freemium user has now used their free deep research + if err := s.firebaseClient.MarkFreeDeepResearchUsed(ctx, userID); err != nil { + log.Error("failed to mark freemium deep research as used", slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to track usage"}`)) + clientConn.Close() + return + } + + log.Info("freemium user is using their free deep research", slog.String("user_id", userID)) + } // Construct WebSocket URL for the deep research server deepResearchHost := os.Getenv("DEEP_RESEARCH_WS") From da467888872a7675ae98323743c55eff832ec7e3 Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Wed, 8 Oct 2025 16:28:06 -0700 Subject: [PATCH 2/8] feat: saving state proxy server --- docs/deep-research-reconnection.md | 253 +++++++++++++++++++++++++++++ docs/deep-research.md | 13 ++ internal/deepr/models.go | 32 +++- internal/deepr/service.go | 245 
++++++++++++++++++++++++---- internal/deepr/session_manager.go | 211 ++++++++++++++++++++++++ internal/deepr/storage.go | 250 ++++++++++++++++++++++++++++ 6 files changed, 965 insertions(+), 39 deletions(-) create mode 100644 docs/deep-research-reconnection.md create mode 100644 internal/deepr/session_manager.go create mode 100644 internal/deepr/storage.go diff --git a/docs/deep-research-reconnection.md b/docs/deep-research-reconnection.md new file mode 100644 index 0000000..1ecbd71 --- /dev/null +++ b/docs/deep-research-reconnection.md @@ -0,0 +1,253 @@ +# Deep Research Reconnection Feature + +## Overview + +The Deep Research service now supports automatic session persistence and reconnection, allowing iOS clients to disconnect and reconnect without losing progress or messages. + +## How It Works + +### Architecture + +1. **Session Persistence**: All messages from the deep research backend are stored in JSON files on the enchanted-proxy server +2. **Message State Tracking**: Each message is marked as "sent" or "unsent" (sent to iOS client) +3. **Backend Connection Persistence**: The connection between enchanted-proxy and deep research backend remains active even when iOS disconnects +4. **Automatic Reconnection**: When iOS reconnects, it receives all unsent messages and continues listening for new ones + +### Message Flow + +#### Initial Connection +``` +iOS App → WebSocket → Enchanted Proxy → WebSocket → Deep Research Backend + ↓ + JSON Storage + (Messages persisted) +``` + +#### iOS Disconnects +``` +Deep Research Backend → Enchanted Proxy → JSON Storage (messages marked as "unsent") + (continues running) ↓ + Backend stays connected +``` + +#### iOS Reconnects +``` +iOS App → WebSocket → Enchanted Proxy + ↓ + 1. Loads session state + 2. Sends unsent messages + 3. Continues receiving new messages +``` + +## Key Components + +### 1. 
Storage Layer (`storage.go`) + +Handles persistence of messages and session state: + +- **SessionState**: Tracks session metadata (backend connection status, completion status, etc.) +- **PersistedMessage**: Individual message with sent/unsent status +- **Storage Methods**: + - `LoadSession()`: Load session state from disk + - `SaveSession()`: Save session state to disk + - `AddMessage()`: Store new message with sent status + - `GetUnsentMessages()`: Retrieve messages not yet sent to client + - `MarkMessageAsSent()`: Mark message as delivered + - `IsSessionComplete()`: Check if research is complete + +### 2. Session Manager (`session_manager.go`) + +Manages active backend connections: + +- **ActiveSession**: Represents a live backend connection with multiple client connections +- **SessionManager Methods**: + - `CreateSession()`: Create new backend session + - `GetSession()`: Retrieve active session + - `AddClientConnection()`: Add iOS client to session + - `RemoveClientConnection()`: Remove iOS client from session + - `BroadcastToClients()`: Send message to all connected clients + - `HasActiveBackend()`: Check if backend connection exists + +### 3. Service Updates (`service.go`) + +Enhanced service logic: + +- **Reconnection Detection**: Checks if backend session exists for userID/chatID +- **Message Persistence**: Stores every message with sent/unsent status +- **Unsent Message Delivery**: Sends accumulated messages on reconnection +- **Session Completion Detection**: Recognizes final reports and errors + +## Session States + +### Message States + +1. **sent: true**: Message successfully delivered to iOS client +2. **sent: false**: Message received from backend but not yet delivered to iOS + +### Session Completion Conditions + +A session is considered complete when: +1. **Final Report Received**: Message contains `final_report` field with content +2. 
**Error Occurred**: Message has `type: "error"` or contains `error` field + +## Storage Format + +### Session File Location + +Default: `./deepr_sessions/session_{userID}_{chatID}.json` + +Can be configured via environment variable: `DEEPR_STORAGE_PATH` + +### Session File Structure + +```json +{ + "user_id": "abc123", + "chat_id": "chat456", + "messages": [ + { + "id": "msg-uuid-1", + "user_id": "abc123", + "chat_id": "chat456", + "message": "{\"type\":\"status\",\"content\":\"Starting research...\"}", + "sent": true, + "timestamp": "2025-10-08T10:30:00Z", + "message_type": "status" + }, + { + "id": "msg-uuid-2", + "user_id": "abc123", + "chat_id": "chat456", + "message": "{\"type\":\"update\",\"content\":\"Analyzing sources...\"}", + "sent": false, + "timestamp": "2025-10-08T10:31:00Z", + "message_type": "update" + } + ], + "backend_connected": true, + "last_activity": "2025-10-08T10:31:00Z", + "final_report_received": false, + "error_occurred": false +} +``` + +## Reconnection Behavior + +### On Reconnection + +1. **Check Session**: Determine if backend session exists +2. **Send Unsent Messages**: Deliver all messages marked as unsent +3. **Check Completion**: + - If `final_report_received` or `error_occurred`: Send final message and close + - Otherwise: Continue listening for new messages +4. 
**Join Active Session**: Add client to existing session for real-time updates + +### Multiple Clients + +The system supports multiple iOS clients connected to the same session: + +- Messages are broadcast to all connected clients +- Each client can disconnect/reconnect independently +- Backend connection persists as long as research is ongoing + +## Configuration + +### Environment Variables + +```bash +# Storage path for session files (optional) +DEEPR_STORAGE_PATH=/path/to/sessions + +# Deep research backend WebSocket URL (required) +DEEP_RESEARCH_WS=your-backend-host:port +``` + +### Default Values + +- Storage Path: `./deepr_sessions` +- Session file naming: `session_{userID}_{chatID}.json` + +## Cleanup + +Session files should be cleaned up periodically using: + +```go +storage.CleanupOldSessions(maxAge time.Duration) +``` + +Recommended: Clean up sessions older than 24-48 hours + +## Error Handling + +### Storage Failures + +If storage operations fail: +- Error is logged +- Message delivery continues +- Reconnection may not have full history + +### Backend Disconnection + +If backend connection drops: +- Session marked as disconnected +- Existing messages remain available +- New connection attempts will create fresh session + +### Client Disconnection + +If iOS client disconnects: +- Backend connection remains active +- Messages continue to be stored as unsent +- Client can reconnect anytime + +## Message Types + +The system recognizes these message types: + +1. **status**: Progress updates +2. **update**: Research updates +3. **error**: Error messages (marks session as complete) +4. **final**: Messages with `final_report` field (marks session as complete) + +## Usage Example + +### iOS Client Flow + +``` +1. Connect to /api/deepresearch/ws?chat_id=123 +2. Send research query +3. Receive status updates +4. (App backgrounded/disconnected) +5. (Messages continue arriving at backend) +6. Reconnect to /api/deepresearch/ws?chat_id=123 +7. 
Receive all unsent messages +8. Continue receiving new messages +9. Receive final report +10. Session complete +``` + +## Monitoring + +Log messages indicate: +- Session creation/removal +- Client connections/disconnections +- Message persistence status +- Reconnection events +- Unsent message delivery + +Example log: +``` +INFO: created new session user_id=abc123 chat_id=chat456 +INFO: message stored sent=false type=status +INFO: detected reconnection user_id=abc123 chat_id=chat456 +INFO: sending unsent messages count=5 +``` + +## Testing Recommendations + +1. **Test Disconnection**: Kill iOS app during research +2. **Test Reconnection**: Reopen app and verify messages received +3. **Test Multiple Clients**: Connect from multiple devices +4. **Test Error Handling**: Verify error messages mark session complete +5. **Test Final Report**: Verify final report marks session complete +6. **Test Storage**: Check session files are created correctly diff --git a/docs/deep-research.md b/docs/deep-research.md index 3a60bf3..17709e4 100644 --- a/docs/deep-research.md +++ b/docs/deep-research.md @@ -195,3 +195,16 @@ All logs include structured fields for user ID, chat ID, and request context. ``` This architecture provides a scalable, secure, and monetizable deep research service with proper usage controls and comprehensive monitoring. + +## Reconnection and Session Persistence + +The deep research service now supports automatic session persistence and reconnection. This allows iOS clients to disconnect (e.g., app backgrounded or killed) and reconnect without losing progress. 
+ +**Key Features:** +- **Session Persistence**: Messages are stored to disk as they arrive from the backend +- **Message State Tracking**: Each message is marked as sent/unsent (delivered to iOS or not) +- **Backend Persistence**: Connection to deep research backend stays active even when iOS disconnects +- **Automatic Recovery**: On reconnection, all unsent messages are delivered automatically +- **Session Completion Detection**: Recognizes when research is complete (final report or error) + +For detailed information about the reconnection feature, see [Deep Research Reconnection Documentation](./deep-research-reconnection.md). diff --git a/internal/deepr/models.go b/internal/deepr/models.go index f5dd953..9e6d621 100644 --- a/internal/deepr/models.go +++ b/internal/deepr/models.go @@ -1,10 +1,14 @@ package deepr +import "time" + // Message represents a WebSocket message for deep research type Message struct { - Type string `json:"type"` - Content string `json:"content"` - Data string `json:"data,omitempty"` + Type string `json:"type"` + Content string `json:"content"` + Data string `json:"data,omitempty"` + FinalReport string `json:"final_report,omitempty"` + Error string `json:"error,omitempty"` } // Request represents a request to the deep research service @@ -19,3 +23,25 @@ type Response struct { Content string `json:"content"` Status string `json:"status,omitempty"` } + +// PersistedMessage represents a message stored to disk +type PersistedMessage struct { + ID string `json:"id"` + UserID string `json:"user_id"` + ChatID string `json:"chat_id"` + Message string `json:"message"` + Sent bool `json:"sent"` + Timestamp time.Time `json:"timestamp"` + MessageType string `json:"message_type"` // "status", "error", "final", etc. 
+} + +// SessionState represents the state of a deep research session +type SessionState struct { + UserID string `json:"user_id"` + ChatID string `json:"chat_id"` + Messages []PersistedMessage `json:"messages"` + BackendConnected bool `json:"backend_connected"` + LastActivity time.Time `json:"last_activity"` + FinalReportReceived bool `json:"final_report_received"` + ErrorOccurred bool `json:"error_occurred"` +} diff --git a/internal/deepr/service.go b/internal/deepr/service.go index 75726ee..0043e4c 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -2,13 +2,16 @@ package deepr import ( "context" + "encoding/json" "log/slog" "net/url" "os" + "path/filepath" "github.com/eternisai/enchanted-proxy/internal/auth" "github.com/eternisai/enchanted-proxy/internal/logger" "github.com/eternisai/enchanted-proxy/internal/request_tracking" + "github.com/google/uuid" "github.com/gorilla/websocket" ) @@ -17,28 +20,76 @@ type Service struct { logger *logger.Logger trackingService *request_tracking.Service firebaseClient *auth.FirebaseClient + storage *Storage + sessionManager *SessionManager } // NewService creates a new deep research service func NewService(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient) *Service { + // Get storage path from environment or use default + storagePath := os.Getenv("DEEPR_STORAGE_PATH") + if storagePath == "" { + storagePath = filepath.Join(".", "deepr_sessions") + } + + storage, err := NewStorage(logger, storagePath) + if err != nil { + logger.WithComponent("deepr").Error("failed to create storage, using in-memory only", + slog.String("error", err.Error())) + } + return &Service{ logger: logger, trackingService: trackingService, firebaseClient: firebaseClient, + storage: storage, + sessionManager: NewSessionManager(logger), } } // HandleConnection manages the WebSocket connection and streaming func (s *Service) HandleConnection(ctx context.Context, clientConn 
*websocket.Conn, userID, chatID string) { log := s.logger.WithContext(ctx).WithComponent("deepr") + clientID := uuid.New().String() + + log.Info("handling new client connection", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("client_id", clientID)) + + // Check if this is a reconnection + isReconnection := s.sessionManager.HasActiveBackend(userID, chatID) + + if isReconnection { + log.Info("detected reconnection to existing session", + slog.String("user_id", userID), + slog.String("chat_id", chatID)) + + // Handle reconnection + s.handleReconnection(ctx, clientConn, userID, chatID, clientID) + return + } + + // New connection - perform subscription checks + if err := s.checkAndTrackSubscription(ctx, clientConn, userID); err != nil { + log.Error("subscription check failed", slog.String("error", err.Error())) + clientConn.Close() + return + } + + // Create new backend connection + s.handleNewConnection(ctx, clientConn, userID, chatID, clientID) +} + +// checkAndTrackSubscription checks user subscription and tracks usage +func (s *Service) checkAndTrackSubscription(ctx context.Context, clientConn *websocket.Conn, userID string) error { + log := s.logger.WithContext(ctx).WithComponent("deepr") - // Check user subscription status hasActivePro, proExpiresAt, err := s.trackingService.HasActivePro(ctx, userID) if err != nil { log.Error("failed to check user subscription status", slog.String("error", err.Error())) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify subscription status"}`)) - clientConn.Close() - return + return err } if hasActivePro { @@ -46,45 +97,125 @@ func (s *Service) HandleConnection(ctx context.Context, clientConn *websocket.Co slog.String("user_id", userID), slog.Time("expires_at", *proExpiresAt)) - // Track usage for pro users (for analytics) if err := s.firebaseClient.IncrementDeepResearchUsage(ctx, userID); err != nil { log.Error("failed to track pro user deep research usage", 
slog.String("error", err.Error())) - // Don't block pro users on tracking error } } else { - // Freemium user - check if they've already used their free deep research log.Info("user is on freemium plan", slog.String("user_id", userID)) hasUsed, err := s.firebaseClient.HasUsedFreeDeepResearch(ctx, userID) if err != nil { log.Error("failed to check freemium deep research usage", slog.String("error", err.Error())) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify usage status"}`)) - clientConn.Close() - return + return err } if hasUsed { log.Info("freemium user has already used their free deep research", slog.String("user_id", userID)) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "You have already used your free deep research. Please upgrade to Pro for unlimited access.", "error_code": "FREE_LIMIT_REACHED"}`)) - clientConn.Close() - return + return err } - // Mark that the freemium user has now used their free deep research if err := s.firebaseClient.MarkFreeDeepResearchUsed(ctx, userID); err != nil { log.Error("failed to mark freemium deep research as used", slog.String("error", err.Error())) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to track usage"}`)) - clientConn.Close() - return + return err } log.Info("freemium user is using their free deep research", slog.String("user_id", userID)) } - // Construct WebSocket URL for the deep research server + return nil +} + +// handleReconnection handles a client reconnecting to an existing session +func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket.Conn, userID, chatID, clientID string) { + log := s.logger.WithContext(ctx).WithComponent("deepr") + + log.Info("handling reconnection", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("client_id", clientID)) + + // Add client to session manager + s.sessionManager.AddClientConnection(userID, chatID, clientID, clientConn) + defer 
s.sessionManager.RemoveClientConnection(userID, chatID, clientID) + + // Check if session is complete + if s.storage != nil { + isComplete, err := s.storage.IsSessionComplete(userID, chatID) + if err != nil { + log.Error("failed to check session completion status", slog.String("error", err.Error())) + } + + // Send unsent messages + unsent, err := s.storage.GetUnsentMessages(userID, chatID) + if err != nil { + log.Error("failed to get unsent messages", slog.String("error", err.Error())) + } else if len(unsent) > 0 { + log.Info("sending unsent messages to reconnected client", + slog.Int("count", len(unsent))) + + for _, msg := range unsent { + if err := clientConn.WriteMessage(websocket.TextMessage, []byte(msg.Message)); err != nil { + log.Error("failed to send unsent message", slog.String("error", err.Error())) + return + } + // Mark as sent + if err := s.storage.MarkMessageAsSent(userID, chatID, msg.ID); err != nil { + log.Error("failed to mark message as sent", slog.String("error", err.Error())) + } + } + } + + if isComplete { + log.Info("session is complete, no more messages expected") + return + } + } + + // Listen for new messages from backend (they'll be broadcast to all clients) + done := make(chan struct{}) + + // Listen for messages from this client + go func() { + defer close(done) + for { + select { + case <-ctx.Done(): + return + default: + _, message, err := clientConn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Error("error reading from reconnected client", slog.String("error", err.Error())) + } + return + } + + log.Info("received message from reconnected client", slog.String("message", string(message))) + + // Forward to backend if session exists + if session, exists := s.sessionManager.GetSession(userID, chatID); exists && session.BackendConn != nil { + if err := session.BackendConn.WriteMessage(websocket.TextMessage, message); err != nil { + 
log.Error("error forwarding message to backend", slog.String("error", err.Error())) + return + } + } + } + } + }() + + <-done +} + +// handleNewConnection creates a new backend connection and manages message flow +func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket.Conn, userID, chatID, clientID string) { + log := s.logger.WithContext(ctx).WithComponent("deepr") + deepResearchHost := os.Getenv("DEEP_RESEARCH_WS") if deepResearchHost == "" { - log.Error("❌ [DeepResearch] DEEP_RESEARCH_WS environment variable not set") + log.Error("DEEP_RESEARCH_WS environment variable not set") clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Deep research backend not configured"}`)) return } @@ -95,76 +226,118 @@ func (s *Service) HandleConnection(ctx context.Context, clientConn *websocket.Co Path: "/deep_research/" + userID + "/" + chatID + "/", } - log.Info("🔌 [DeepResearch] connecting to deep research server", slog.String("url", wsURL.String())) + log.Info("connecting to deep research server", slog.String("url", wsURL.String())) - // Connect to the deep research server serverConn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) if err != nil { - log.Error("❌ [DeepResearch] failed to connect to deep research server", slog.String("error", err.Error())) + log.Error("failed to connect to deep research server", slog.String("error", err.Error())) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to connect to deep research backend"}`)) return } defer serverConn.Close() - log.Info("✅ [DeepResearch] Connected to deep research backend") + log.Info("connected to deep research backend") + + // Update storage + if s.storage != nil { + if err := s.storage.UpdateBackendConnectionStatus(userID, chatID, true); err != nil { + log.Error("failed to update backend connection status", slog.String("error", err.Error())) + } + defer func() { + if err := s.storage.UpdateBackendConnectionStatus(userID, chatID, false); err 
!= nil { + log.Error("failed to update backend disconnection status", slog.String("error", err.Error())) + } + }() + } + + // Create session context + sessionCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Create and register session + _ = s.sessionManager.CreateSession(userID, chatID, serverConn, sessionCtx, cancel) + defer s.sessionManager.RemoveSession(userID, chatID) + + // Add initial client + s.sessionManager.AddClientConnection(userID, chatID, clientID, clientConn) - // Create channels for communication done := make(chan struct{}) - // Start goroutine to handle messages from client to server + // Handle messages from client to backend go func() { defer close(done) for { select { - case <-ctx.Done(): + case <-sessionCtx.Done(): return default: _, message, err := clientConn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Error("❌ [DeepResearch] error reading from client", slog.String("error", err.Error())) + log.Error("error reading from client", slog.String("error", err.Error())) } + s.sessionManager.RemoveClientConnection(userID, chatID, clientID) return } - log.Info("📨 [DeepResearch] Received message from client", slog.String("message", string(message))) + log.Info("received message from client", slog.String("message", string(message))) - // Forward message to server if err := serverConn.WriteMessage(websocket.TextMessage, message); err != nil { - log.Error("❌ [DeepResearch] error writing to server", slog.String("error", err.Error())) + log.Error("error writing to server", slog.String("error", err.Error())) return } - log.Info("📤 [DeepResearch] message forwarded to deep research backend") + log.Info("message forwarded to deep research backend") } } }() - // Handle messages from server to client + // Handle messages from backend to clients for { select { case <-done: return - case <-ctx.Done(): + case <-sessionCtx.Done(): return default: _, message, err := 
serverConn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Error("❌ [DeepResearch] error reading from backend", slog.String("error", err.Error())) + log.Error("error reading from backend", slog.String("error", err.Error())) } return } - log.Info("📨 [DeepResearch] Received message from backend", slog.String("message", string(message))) + log.Info("received message from backend", slog.String("message", string(message))) - // Forward message to client - if err := clientConn.WriteMessage(websocket.TextMessage, message); err != nil { - log.Error("❌ [DeepResearch] error writing to client", slog.String("error", err.Error())) - return + // Determine message type + var msg Message + messageType := "status" + if err := json.Unmarshal(message, &msg); err == nil { + if msg.Type != "" { + messageType = msg.Type + } } - log.Info("📤 [DeepResearch] message forwarded to client") + // Store message + messageSent := false + if s.storage != nil { + // Try to broadcast to clients + broadcastErr := s.sessionManager.BroadcastToClients(userID, chatID, message) + messageSent = (broadcastErr == nil && s.sessionManager.GetClientCount(userID, chatID) > 0) + + // Store message with sent status + if err := s.storage.AddMessage(userID, chatID, string(message), messageSent, messageType); err != nil { + log.Error("failed to store message", slog.String("error", err.Error())) + } else { + log.Info("message stored", + slog.Bool("sent", messageSent), + slog.String("type", messageType)) + } + } else { + // No storage, just broadcast + s.sessionManager.BroadcastToClients(userID, chatID, message) + } } } } diff --git a/internal/deepr/session_manager.go b/internal/deepr/session_manager.go new file mode 100644 index 0000000..f34bc0b --- /dev/null +++ b/internal/deepr/session_manager.go @@ -0,0 +1,211 @@ +package deepr + +import ( + "context" + "log/slog" + "sync" + + 
"github.com/eternisai/enchanted-proxy/internal/logger" + "github.com/gorilla/websocket" +) + +// ActiveSession represents an active backend connection +type ActiveSession struct { + UserID string + ChatID string + BackendConn *websocket.Conn + Context context.Context + CancelFunc context.CancelFunc + mu sync.RWMutex + clientConns map[string]*websocket.Conn // Map of client connection IDs +} + +// SessionManager manages active backend connections +type SessionManager struct { + logger *logger.Logger + sessions map[string]*ActiveSession // key: "userID:chatID" + mu sync.RWMutex +} + +// NewSessionManager creates a new session manager +func NewSessionManager(logger *logger.Logger) *SessionManager { + return &SessionManager{ + logger: logger, + sessions: make(map[string]*ActiveSession), + } +} + +// getSessionKey generates a session key from userID and chatID +func (sm *SessionManager) getSessionKey(userID, chatID string) string { + return userID + ":" + chatID +} + +// GetSession retrieves an active session +func (sm *SessionManager) GetSession(userID, chatID string) (*ActiveSession, bool) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + key := sm.getSessionKey(userID, chatID) + session, exists := sm.sessions[key] + return session, exists +} + +// CreateSession creates a new active session +func (sm *SessionManager) CreateSession(userID, chatID string, backendConn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *ActiveSession { + sm.mu.Lock() + defer sm.mu.Unlock() + + key := sm.getSessionKey(userID, chatID) + + session := &ActiveSession{ + UserID: userID, + ChatID: chatID, + BackendConn: backendConn, + Context: ctx, + CancelFunc: cancel, + clientConns: make(map[string]*websocket.Conn), + } + + sm.sessions[key] = session + + sm.logger.WithComponent("deepr-session").Info("created new session", + slog.String("user_id", userID), + slog.String("chat_id", chatID)) + + return session +} + +// RemoveSession removes a session +func (sm *SessionManager) 
RemoveSession(userID, chatID string) { + sm.mu.Lock() + defer sm.mu.Unlock() + + key := sm.getSessionKey(userID, chatID) + + if session, exists := sm.sessions[key]; exists { + // Close all client connections + session.mu.Lock() + for clientID, conn := range session.clientConns { + conn.Close() + sm.logger.WithComponent("deepr-session").Info("closed client connection during session removal", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("client_id", clientID)) + } + session.clientConns = make(map[string]*websocket.Conn) + session.mu.Unlock() + + // Cancel context + if session.CancelFunc != nil { + session.CancelFunc() + } + + delete(sm.sessions, key) + + sm.logger.WithComponent("deepr-session").Info("removed session", + slog.String("user_id", userID), + slog.String("chat_id", chatID)) + } +} + +// AddClientConnection adds a client connection to an existing session +func (sm *SessionManager) AddClientConnection(userID, chatID, clientID string, conn *websocket.Conn) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + key := sm.getSessionKey(userID, chatID) + if session, exists := sm.sessions[key]; exists { + session.mu.Lock() + session.clientConns[clientID] = conn + session.mu.Unlock() + + sm.logger.WithComponent("deepr-session").Info("added client connection to session", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("client_id", clientID)) + } +} + +// RemoveClientConnection removes a client connection from a session +func (sm *SessionManager) RemoveClientConnection(userID, chatID, clientID string) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + key := sm.getSessionKey(userID, chatID) + if session, exists := sm.sessions[key]; exists { + session.mu.Lock() + delete(session.clientConns, clientID) + clientCount := len(session.clientConns) + session.mu.Unlock() + + sm.logger.WithComponent("deepr-session").Info("removed client connection from session", + slog.String("user_id", userID), + slog.String("chat_id", 
chatID), + slog.String("client_id", clientID), + slog.Int("remaining_clients", clientCount)) + } +} + +// BroadcastToClients sends a message to all connected clients for a session +func (sm *SessionManager) BroadcastToClients(userID, chatID string, message []byte) error { + sm.mu.RLock() + key := sm.getSessionKey(userID, chatID) + session, exists := sm.sessions[key] + sm.mu.RUnlock() + + if !exists { + return nil // No active session, message will be stored as unsent + } + + session.mu.RLock() + defer session.mu.RUnlock() + + var lastErr error + sentCount := 0 + + for clientID, conn := range session.clientConns { + if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { + sm.logger.WithComponent("deepr-session").Error("failed to send message to client", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("client_id", clientID), + slog.String("error", err.Error())) + lastErr = err + } else { + sentCount++ + } + } + + sm.logger.WithComponent("deepr-session").Debug("broadcast message to clients", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Int("sent_count", sentCount), + slog.Int("total_clients", len(session.clientConns))) + + return lastErr +} + +// GetClientCount returns the number of connected clients for a session +func (sm *SessionManager) GetClientCount(userID, chatID string) int { + sm.mu.RLock() + defer sm.mu.RUnlock() + + key := sm.getSessionKey(userID, chatID) + if session, exists := sm.sessions[key]; exists { + session.mu.RLock() + defer session.mu.RUnlock() + return len(session.clientConns) + } + + return 0 +} + +// HasActiveBackend checks if there's an active backend connection for a session +func (sm *SessionManager) HasActiveBackend(userID, chatID string) bool { + sm.mu.RLock() + defer sm.mu.RUnlock() + + key := sm.getSessionKey(userID, chatID) + _, exists := sm.sessions[key] + return exists +} diff --git a/internal/deepr/storage.go b/internal/deepr/storage.go new file mode 
100644 index 0000000..5309085 --- /dev/null +++ b/internal/deepr/storage.go @@ -0,0 +1,250 @@ +package deepr + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + "sync" + "time" + + "github.com/eternisai/enchanted-proxy/internal/logger" + "github.com/google/uuid" +) + +// Storage handles persistence of deep research messages +type Storage struct { + logger *logger.Logger + storagePath string + mu sync.RWMutex +} + +// NewStorage creates a new storage instance +func NewStorage(logger *logger.Logger, storagePath string) (*Storage, error) { + // Create storage directory if it doesn't exist + if err := os.MkdirAll(storagePath, 0755); err != nil { + return nil, fmt.Errorf("failed to create storage directory: %w", err) + } + + return &Storage{ + logger: logger, + storagePath: storagePath, + }, nil +} + +// getSessionFilePath returns the file path for a session +func (s *Storage) getSessionFilePath(userID, chatID string) string { + filename := fmt.Sprintf("session_%s_%s.json", userID, chatID) + return filepath.Join(s.storagePath, filename) +} + +// LoadSession loads a session state from disk +func (s *Storage) LoadSession(userID, chatID string) (*SessionState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + filePath := s.getSessionFilePath(userID, chatID) + + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + // No existing session, return new one + return &SessionState{ + UserID: userID, + ChatID: chatID, + Messages: []PersistedMessage{}, + BackendConnected: false, + LastActivity: time.Now(), + FinalReportReceived: false, + ErrorOccurred: false, + }, nil + } + return nil, fmt.Errorf("failed to read session file: %w", err) + } + + var state SessionState + if err := json.Unmarshal(data, &state); err != nil { + return nil, fmt.Errorf("failed to unmarshal session: %w", err) + } + + return &state, nil +} + +// SaveSession saves a session state to disk +func (s *Storage) SaveSession(state *SessionState) error { + 
s.mu.Lock() + defer s.mu.Unlock() + + filePath := s.getSessionFilePath(state.UserID, state.ChatID) + + state.LastActivity = time.Now() + + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal session: %w", err) + } + + if err := os.WriteFile(filePath, data, 0644); err != nil { + return fmt.Errorf("failed to write session file: %w", err) + } + + return nil +} + +// AddMessage adds a new message to the session +func (s *Storage) AddMessage(userID, chatID, message string, sent bool, messageType string) error { + state, err := s.LoadSession(userID, chatID) + if err != nil { + return err + } + + persistedMsg := PersistedMessage{ + ID: uuid.New().String(), + UserID: userID, + ChatID: chatID, + Message: message, + Sent: sent, + Timestamp: time.Now(), + MessageType: messageType, + } + + state.Messages = append(state.Messages, persistedMsg) + + // Check if this is a final report or error + var msg Message + if err := json.Unmarshal([]byte(message), &msg); err == nil { + if msg.FinalReport != "" { + state.FinalReportReceived = true + } + if msg.Type == "error" || msg.Error != "" { + state.ErrorOccurred = true + } + } + + return s.SaveSession(state) +} + +// MarkMessageAsSent marks a specific message as sent +func (s *Storage) MarkMessageAsSent(userID, chatID, messageID string) error { + state, err := s.LoadSession(userID, chatID) + if err != nil { + return err + } + + for i := range state.Messages { + if state.Messages[i].ID == messageID { + state.Messages[i].Sent = true + break + } + } + + return s.SaveSession(state) +} + +// MarkAllMessagesAsSent marks all messages up to a certain index as sent +func (s *Storage) MarkAllMessagesAsSent(userID, chatID string) error { + state, err := s.LoadSession(userID, chatID) + if err != nil { + return err + } + + for i := range state.Messages { + state.Messages[i].Sent = true + } + + return s.SaveSession(state) +} + +// GetUnsentMessages returns all unsent messages for a session +func 
(s *Storage) GetUnsentMessages(userID, chatID string) ([]PersistedMessage, error) { + state, err := s.LoadSession(userID, chatID) + if err != nil { + return nil, err + } + + var unsent []PersistedMessage + for _, msg := range state.Messages { + if !msg.Sent { + unsent = append(unsent, msg) + } + } + + return unsent, nil +} + +// GetLastUnsentMessage returns the last unsent message for a session +func (s *Storage) GetLastUnsentMessage(userID, chatID string) (*PersistedMessage, error) { + unsent, err := s.GetUnsentMessages(userID, chatID) + if err != nil { + return nil, err + } + + if len(unsent) == 0 { + return nil, nil + } + + return &unsent[len(unsent)-1], nil +} + +// UpdateBackendConnectionStatus updates the backend connection status +func (s *Storage) UpdateBackendConnectionStatus(userID, chatID string, connected bool) error { + state, err := s.LoadSession(userID, chatID) + if err != nil { + return err + } + + state.BackendConnected = connected + return s.SaveSession(state) +} + +// IsSessionComplete checks if a session is complete (has final report or error) +func (s *Storage) IsSessionComplete(userID, chatID string) (bool, error) { + state, err := s.LoadSession(userID, chatID) + if err != nil { + return false, err + } + + return state.FinalReportReceived || state.ErrorOccurred, nil +} + +// CleanupOldSessions removes session files older than the specified duration +func (s *Storage) CleanupOldSessions(maxAge time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + + files, err := os.ReadDir(s.storagePath) + if err != nil { + return fmt.Errorf("failed to read storage directory: %w", err) + } + + now := time.Now() + for _, file := range files { + if file.IsDir() { + continue + } + + filePath := filepath.Join(s.storagePath, file.Name()) + info, err := file.Info() + if err != nil { + s.logger.WithComponent("deepr-storage").Error("failed to get file info", + slog.String("file", file.Name()), + slog.String("error", err.Error())) + continue + } + + if 
now.Sub(info.ModTime()) > maxAge { + if err := os.Remove(filePath); err != nil { + s.logger.WithComponent("deepr-storage").Error("failed to remove old session file", + slog.String("file", file.Name()), + slog.String("error", err.Error())) + } else { + s.logger.WithComponent("deepr-storage").Info("removed old session file", + slog.String("file", file.Name()), + slog.Duration("age", now.Sub(info.ModTime()))) + } + } + } + + return nil +} From 5f13e17040066372ac1d311052cf8c73f00cf0fd Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Wed, 8 Oct 2025 18:29:37 -0700 Subject: [PATCH 3/8] fix: default value of the deep research ws --- internal/deepr/service.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/internal/deepr/service.go b/internal/deepr/service.go index 0043e4c..a504530 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -215,9 +215,8 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket deepResearchHost := os.Getenv("DEEP_RESEARCH_WS") if deepResearchHost == "" { - log.Error("DEEP_RESEARCH_WS environment variable not set") - clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Deep research backend not configured"}`)) - return + deepResearchHost = "165.232.133.47:3031" + log.Info("DEEP_RESEARCH_WS environment variable not set, using default", slog.String("default", deepResearchHost)) } wsURL := url.URL{ From ff1e116bec25ec65712bfd89bd91c61dd264e543 Mon Sep 17 00:00:00 2001 From: Yaroslav Zhavoronkov Date: Wed, 8 Oct 2025 19:59:53 -0700 Subject: [PATCH 4/8] chore: add deepresearch-related deployment configuration --- deploy/enclaver.yaml | 4 ++++ deploy/envoy.yaml | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/deploy/enclaver.yaml b/deploy/enclaver.yaml index a22bd26..5581e1c 100644 --- a/deploy/enclaver.yaml +++ b/deploy/enclaver.yaml @@ -45,6 +45,8 @@ egress: # aws-0-us-east-2.pooler.supabase.com - 3.13.175.194 - 3.139.14.59 + # demo 
deepresearch endpoint (temporary) + - 165.232.133.47:3031 env: - APPSTORE_API_KEY_ID - APPSTORE_API_KEY_P8 @@ -58,6 +60,8 @@ env: - DB_CONN_MAX_LIFETIME_MINUTES - DB_MAX_IDLE_CONNS - DB_MAX_OPEN_CONNS +- DEEPR_STORAGE_PATH +- DEEP_RESEARCH_WS - DRIP_DAILY_MESSAGES - ENABLE_TELEGRAM_SERVER - EXA_API_KEY diff --git a/deploy/envoy.yaml b/deploy/envoy.yaml index 5e07c27..1644dc1 100644 --- a/deploy/envoy.yaml +++ b/deploy/envoy.yaml @@ -93,6 +93,24 @@ static_resources: upstream_wire_bytes_sent: '%UPSTREAM_WIRE_BYTES_SENT%' access_log_options: flush_access_log_on_connected: true + # demo deepresearch endpoint (temporary) + - name: egress_deepresearch + address: + socket_address: + address: 0.0.0.0 + port_value: 3031 + filter_chains: + - filters: + - name: envoy.filters.network.tcp_proxy + typed_config: + '@type': type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy + stat_prefix: tcp_proxy + cluster: enclaver_odyn_egress + tunneling_config: + hostname: '%DOWNSTREAM_DIRECT_REMOTE_ADDRESS_WITHOUT_PORT%:3031' + access_log: *access_log_tcp + access_log_options: + flush_access_log_on_connected: true - name: egress_postgresql address: socket_address: From ff5be2f8b96b57afa37b4bbfe53be0ac452911dc Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Thu, 9 Oct 2025 02:17:32 -0700 Subject: [PATCH 5/8] fix: coderrabbit bot potential bugs fixing --- internal/deepr/service.go | 97 +++++++++++++++------------ internal/deepr/storage.go | 137 ++++++++++++++++++++++++-------------- 2 files changed, 142 insertions(+), 92 deletions(-) diff --git a/internal/deepr/service.go b/internal/deepr/service.go index a504530..98f8862 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -3,6 +3,7 @@ package deepr import ( "context" "encoding/json" + "fmt" "log/slog" "net/url" "os" @@ -113,7 +114,7 @@ func (s *Service) checkAndTrackSubscription(ctx context.Context, clientConn *web if hasUsed { log.Info("freemium user has already used their free deep 
research", slog.String("user_id", userID)) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "You have already used your free deep research. Please upgrade to Pro for unlimited access.", "error_code": "FREE_LIMIT_REACHED"}`)) - return err + return fmt.Errorf("freemium quota exhausted for user %s", userID) } if err := s.firebaseClient.MarkFreeDeepResearchUsed(ctx, userID); err != nil { @@ -137,18 +138,15 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. slog.String("chat_id", chatID), slog.String("client_id", clientID)) - // Add client to session manager - s.sessionManager.AddClientConnection(userID, chatID, clientID, clientConn) - defer s.sessionManager.RemoveClientConnection(userID, chatID, clientID) - - // Check if session is complete + // Check if session is complete and replay unsent messages BEFORE adding client to session manager + // This prevents concurrent writes: backend broadcast won't know about this client during replay if s.storage != nil { isComplete, err := s.storage.IsSessionComplete(userID, chatID) if err != nil { log.Error("failed to check session completion status", slog.String("error", err.Error())) } - // Send unsent messages + // Send unsent messages before registering the connection unsent, err := s.storage.GetUnsentMessages(userID, chatID) if err != nil { log.Error("failed to get unsent messages", slog.String("error", err.Error())) @@ -174,6 +172,10 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. } } + // Now that replay is complete, add client to session manager for future broadcasts + s.sessionManager.AddClientConnection(userID, chatID, clientID, clientConn) + defer s.sessionManager.RemoveClientConnection(userID, chatID, clientID) + // Listen for new messages from backend (they'll be broadcast to all clients) done := make(chan struct{}) @@ -209,6 +211,38 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. 
<-done } +// handleClientMessages handles forwarding messages from a client to the backend +func (s *Service) handleClientMessages(ctx context.Context, clientConn *websocket.Conn, userID, chatID, clientID string) { + log := s.logger.WithContext(ctx).WithComponent("deepr") + + for { + select { + case <-ctx.Done(): + return + default: + _, message, err := clientConn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Error("error reading from client", slog.String("error", err.Error())) + } + s.sessionManager.RemoveClientConnection(userID, chatID, clientID) + return + } + + log.Info("received message from client", slog.String("message", string(message))) + + // Forward to backend if session exists + if session, exists := s.sessionManager.GetSession(userID, chatID); exists && session.BackendConn != nil { + if err := session.BackendConn.WriteMessage(websocket.TextMessage, message); err != nil { + log.Error("error forwarding message to backend", slog.String("error", err.Error())) + return + } + log.Info("message forwarded to deep research backend") + } + } + } +} + // handleNewConnection creates a new backend connection and manages message flow func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket.Conn, userID, chatID, clientID string) { log := s.logger.WithContext(ctx).WithComponent("deepr") @@ -249,8 +283,8 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket }() } - // Create session context - sessionCtx, cancel := context.WithCancel(ctx) + // Create session context that is independent of any single client + sessionCtx, cancel := context.WithCancel(context.Background()) defer cancel() // Create and register session @@ -260,43 +294,14 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket // Add initial client s.sessionManager.AddClientConnection(userID, chatID, clientID, clientConn) - done := 
make(chan struct{}) + // Handle messages from this client to backend in a separate goroutine + go s.handleClientMessages(ctx, clientConn, userID, chatID, clientID) - // Handle messages from client to backend - go func() { - defer close(done) - for { - select { - case <-sessionCtx.Done(): - return - default: - _, message, err := clientConn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Error("error reading from client", slog.String("error", err.Error())) - } - s.sessionManager.RemoveClientConnection(userID, chatID, clientID) - return - } - - log.Info("received message from client", slog.String("message", string(message))) - - if err := serverConn.WriteMessage(websocket.TextMessage, message); err != nil { - log.Error("error writing to server", slog.String("error", err.Error())) - return - } - - log.Info("message forwarded to deep research backend") - } - } - }() - - // Handle messages from backend to clients + // Handle messages from backend to clients - this loop runs until backend disconnects for { select { - case <-done: - return case <-sessionCtx.Done(): + log.Info("session context cancelled, stopping backend reader") return default: _, message, err := serverConn.ReadMessage() @@ -304,6 +309,7 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Error("error reading from backend", slog.String("error", err.Error())) } + log.Info("backend connection closed, session will be removed") return } @@ -333,6 +339,13 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket slog.Bool("sent", messageSent), slog.String("type", messageType)) } + + // Check if session is complete + if msg.FinalReport != "" || msg.Type == "error" || msg.Error != "" { + log.Info("session complete, backend will be closed after final message 
broadcast") + // Continue to allow reconnecting clients to receive this final message + // The backend connection will close naturally or via timeout + } } else { // No storage, just broadcast s.sessionManager.BroadcastToClients(userID, chatID, message) diff --git a/internal/deepr/storage.go b/internal/deepr/storage.go index 5309085..39fe4a7 100644 --- a/internal/deepr/storage.go +++ b/internal/deepr/storage.go @@ -76,6 +76,40 @@ func (s *Storage) SaveSession(state *SessionState) error { s.mu.Lock() defer s.mu.Unlock() + return s.saveSessionUnsafe(state) +} + +// loadSessionUnsafe loads a session without acquiring locks (internal use only) +func (s *Storage) loadSessionUnsafe(userID, chatID string) (*SessionState, error) { + filePath := s.getSessionFilePath(userID, chatID) + + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + // No existing session, return new one + return &SessionState{ + UserID: userID, + ChatID: chatID, + Messages: []PersistedMessage{}, + BackendConnected: false, + LastActivity: time.Now(), + FinalReportReceived: false, + ErrorOccurred: false, + }, nil + } + return nil, fmt.Errorf("failed to read session file: %w", err) + } + + var state SessionState + if err := json.Unmarshal(data, &state); err != nil { + return nil, fmt.Errorf("failed to unmarshal session: %w", err) + } + + return &state, nil +} + +// saveSessionUnsafe saves a session without acquiring locks (internal use only) +func (s *Storage) saveSessionUnsafe(state *SessionState) error { filePath := s.getSessionFilePath(state.UserID, state.ChatID) state.LastActivity = time.Now() @@ -92,68 +126,74 @@ func (s *Storage) SaveSession(state *SessionState) error { return nil } -// AddMessage adds a new message to the session -func (s *Storage) AddMessage(userID, chatID, message string, sent bool, messageType string) error { - state, err := s.LoadSession(userID, chatID) +// modifySession executes a mutation function on a session while holding the write lock +func (s 
*Storage) modifySession(userID, chatID string, mutate func(*SessionState) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.loadSessionUnsafe(userID, chatID) if err != nil { return err } - persistedMsg := PersistedMessage{ - ID: uuid.New().String(), - UserID: userID, - ChatID: chatID, - Message: message, - Sent: sent, - Timestamp: time.Now(), - MessageType: messageType, + if err := mutate(state); err != nil { + return err } - state.Messages = append(state.Messages, persistedMsg) + return s.saveSessionUnsafe(state) +} - // Check if this is a final report or error - var msg Message - if err := json.Unmarshal([]byte(message), &msg); err == nil { - if msg.FinalReport != "" { - state.FinalReportReceived = true +// AddMessage adds a new message to the session +func (s *Storage) AddMessage(userID, chatID, message string, sent bool, messageType string) error { + return s.modifySession(userID, chatID, func(state *SessionState) error { + persistedMsg := PersistedMessage{ + ID: uuid.New().String(), + UserID: userID, + ChatID: chatID, + Message: message, + Sent: sent, + Timestamp: time.Now(), + MessageType: messageType, } - if msg.Type == "error" || msg.Error != "" { - state.ErrorOccurred = true + + state.Messages = append(state.Messages, persistedMsg) + + // Check if this is a final report or error + var msg Message + if err := json.Unmarshal([]byte(message), &msg); err == nil { + if msg.FinalReport != "" { + state.FinalReportReceived = true + } + if msg.Type == "error" || msg.Error != "" { + state.ErrorOccurred = true + } } - } - return s.SaveSession(state) + return nil + }) } // MarkMessageAsSent marks a specific message as sent func (s *Storage) MarkMessageAsSent(userID, chatID, messageID string) error { - state, err := s.LoadSession(userID, chatID) - if err != nil { - return err - } - - for i := range state.Messages { - if state.Messages[i].ID == messageID { - state.Messages[i].Sent = true - break + return s.modifySession(userID, chatID, func(state 
*SessionState) error { + for i := range state.Messages { + if state.Messages[i].ID == messageID { + state.Messages[i].Sent = true + break + } } - } - - return s.SaveSession(state) + return nil + }) } // MarkAllMessagesAsSent marks all messages up to a certain index as sent func (s *Storage) MarkAllMessagesAsSent(userID, chatID string) error { - state, err := s.LoadSession(userID, chatID) - if err != nil { - return err - } - - for i := range state.Messages { - state.Messages[i].Sent = true - } - - return s.SaveSession(state) + return s.modifySession(userID, chatID, func(state *SessionState) error { + for i := range state.Messages { + state.Messages[i].Sent = true + } + return nil + }) } // GetUnsentMessages returns all unsent messages for a session @@ -189,13 +229,10 @@ func (s *Storage) GetLastUnsentMessage(userID, chatID string) (*PersistedMessage // UpdateBackendConnectionStatus updates the backend connection status func (s *Storage) UpdateBackendConnectionStatus(userID, chatID string, connected bool) error { - state, err := s.LoadSession(userID, chatID) - if err != nil { - return err - } - - state.BackendConnected = connected - return s.SaveSession(state) + return s.modifySession(userID, chatID, func(state *SessionState) error { + state.BackendConnected = connected + return nil + }) } // IsSessionComplete checks if a session is complete (has final report or error) From 41d6c41e94fba9bd32b0c764ec18caf3b6b64b75 Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Thu, 9 Oct 2025 02:26:54 -0700 Subject: [PATCH 6/8] fix: coderabbit potential bug fixes --- internal/deepr/service.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/deepr/service.go b/internal/deepr/service.go index 98f8862..6cc78da 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -157,6 +157,7 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. 
for _, msg := range unsent { if err := clientConn.WriteMessage(websocket.TextMessage, []byte(msg.Message)); err != nil { log.Error("failed to send unsent message", slog.String("error", err.Error())) + clientConn.Close() return } // Mark as sent @@ -168,6 +169,7 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. if isComplete { log.Info("session is complete, no more messages expected") + clientConn.Close() return } } From 1aac33295e77c4c1e9e8618a0c88f41d90a8ba9c Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Thu, 9 Oct 2025 02:37:39 -0700 Subject: [PATCH 7/8] fix: Serialize writes to the backend websocket. --- internal/deepr/service.go | 22 ++++++++---------- internal/deepr/session_manager.go | 38 +++++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/internal/deepr/service.go b/internal/deepr/service.go index 6cc78da..8a7b568 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -199,12 +199,10 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. 
log.Info("received message from reconnected client", slog.String("message", string(message))) - // Forward to backend if session exists - if session, exists := s.sessionManager.GetSession(userID, chatID); exists && session.BackendConn != nil { - if err := session.BackendConn.WriteMessage(websocket.TextMessage, message); err != nil { - log.Error("error forwarding message to backend", slog.String("error", err.Error())) - return - } + // Forward to backend using synchronized write + if err := s.sessionManager.WriteToBackend(userID, chatID, websocket.TextMessage, message); err != nil { + log.Error("error forwarding message to backend", slog.String("error", err.Error())) + return } } } @@ -233,14 +231,12 @@ func (s *Service) handleClientMessages(ctx context.Context, clientConn *websocke log.Info("received message from client", slog.String("message", string(message))) - // Forward to backend if session exists - if session, exists := s.sessionManager.GetSession(userID, chatID); exists && session.BackendConn != nil { - if err := session.BackendConn.WriteMessage(websocket.TextMessage, message); err != nil { - log.Error("error forwarding message to backend", slog.String("error", err.Error())) - return - } - log.Info("message forwarded to deep research backend") + // Forward to backend using synchronized write + if err := s.sessionManager.WriteToBackend(userID, chatID, websocket.TextMessage, message); err != nil { + log.Error("error forwarding message to backend", slog.String("error", err.Error())) + return } + log.Info("message forwarded to deep research backend") } } } diff --git a/internal/deepr/session_manager.go b/internal/deepr/session_manager.go index f34bc0b..e25bfcd 100644 --- a/internal/deepr/session_manager.go +++ b/internal/deepr/session_manager.go @@ -11,13 +11,14 @@ import ( // ActiveSession represents an active backend connection type ActiveSession struct { - UserID string - ChatID string - BackendConn *websocket.Conn - Context context.Context - CancelFunc 
context.CancelFunc - mu sync.RWMutex - clientConns map[string]*websocket.Conn // Map of client connection IDs + UserID string + ChatID string + BackendConn *websocket.Conn + Context context.Context + CancelFunc context.CancelFunc + mu sync.RWMutex // Protects clientConns map + backendWriteMu sync.Mutex // Serializes writes to backend websocket + clientConns map[string]*websocket.Conn // Map of client connection IDs } // SessionManager manages active backend connections @@ -209,3 +210,26 @@ func (sm *SessionManager) HasActiveBackend(userID, chatID string) bool { _, exists := sm.sessions[key] return exists } + +// WriteToBackend sends a message to the backend websocket with proper synchronization +// This method ensures only one goroutine writes to the backend at a time +func (sm *SessionManager) WriteToBackend(userID, chatID string, messageType int, message []byte) error { + sm.mu.RLock() + key := sm.getSessionKey(userID, chatID) + session, exists := sm.sessions[key] + sm.mu.RUnlock() + + if !exists { + return nil // No active session + } + + // Serialize writes to backend websocket + session.backendWriteMu.Lock() + defer session.backendWriteMu.Unlock() + + if session.BackendConn == nil { + return nil // Backend connection closed + } + + return session.BackendConn.WriteMessage(messageType, message) +} From f3a426e40c0ec6f3841142f169c3336397c5d5d4 Mon Sep 17 00:00:00 2001 From: David Mayboroda Date: Thu, 9 Oct 2025 02:51:47 -0700 Subject: [PATCH 8/8] fix: Complete session doesn't close backend connection. 
--- internal/deepr/service.go | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/internal/deepr/service.go b/internal/deepr/service.go index 8a7b568..e892b44 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "path/filepath" + "time" "github.com/eternisai/enchanted-proxy/internal/auth" "github.com/eternisai/enchanted-proxy/internal/logger" @@ -259,7 +260,11 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket log.Info("connecting to deep research server", slog.String("url", wsURL.String())) - serverConn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + // Create dialer with timeout to prevent indefinite hangs + dialer := *websocket.DefaultDialer + dialer.HandshakeTimeout = 30 * time.Second + + serverConn, _, err := dialer.Dial(wsURL.String(), nil) if err != nil { log.Error("failed to connect to deep research server", slog.String("error", err.Error())) clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to connect to deep research backend"}`)) @@ -281,8 +286,9 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket }() } - // Create session context that is independent of any single client - sessionCtx, cancel := context.WithCancel(context.Background()) + // Create session context derived from incoming context but independent of any single client + // This ensures cleanup on server shutdown while allowing session to outlive individual clients + sessionCtx, cancel := context.WithCancel(ctx) defer cancel() // Create and register session @@ -340,13 +346,23 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket // Check if session is complete if msg.FinalReport != "" || msg.Type == "error" || msg.Error != "" { - log.Info("session complete, backend will be closed after final message broadcast") - // Continue to allow reconnecting clients to receive 
this final message - // The backend connection will close naturally or via timeout + log.Info("session complete, closing backend connection and cleaning up") + // Final message has been stored and broadcast, now clean up + // This cancels the session context and exits the loop + // Defers will close backend connection and remove session from manager + cancel() + return } } else { // No storage, just broadcast s.sessionManager.BroadcastToClients(userID, chatID, message) + + // Check if session is complete even without storage + if msg.FinalReport != "" || msg.Type == "error" || msg.Error != "" { + log.Info("session complete (no storage), closing backend connection and cleaning up") + cancel() + return + } } } }