diff --git a/.gitignore b/.gitignore index ec91195..b7e8e7f 100644 --- a/.gitignore +++ b/.gitignore @@ -57,4 +57,7 @@ temp/ tests/ # Scripts -__pycache__/ \ No newline at end of file +__pycache__/ + +# Deep research session files (deprecated, now using database) +deepr_sessions/ \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go index c633596..4bb8b86 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -129,6 +129,9 @@ func main() { mcpService := mcp.NewService() searchService := search.NewService(logger.WithComponent("search")) + // Initialize deep research storage + deeprStorage := deepr.NewDBStorage(logger.WithComponent("deepr-storage"), db.DB) + // Initialize handlers oauthHandler := oauth.NewHandler(oauthService, logger.WithComponent("oauth")) composioHandler := composio.NewHandler(composioService, logger.WithComponent("composio")) @@ -190,6 +193,7 @@ func main() { iapHandler: iapHandler, mcpHandler: mcpHandler, searchHandler: searchHandler, + deeprStorage: deeprStorage, }) // Initialize GraphQL server for Telegram @@ -290,6 +294,7 @@ type restServerInput struct { iapHandler *iap.Handler mcpHandler *mcp.Handler searchHandler *search.Handler + deeprStorage deepr.MessageStorage } func setupRESTServer(input restServerInput) *gin.Engine { @@ -364,7 +369,7 @@ func setupRESTServer(input restServerInput) *gin.Engine { api.POST("/exa/search", input.searchHandler.PostExaSearchHandler) // POST /api/v1/exa/search (Exa AI) // Deep Research WebSocket endpoint (protected) - api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient)) // WebSocket proxy for deep research + api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient, input.deeprStorage)) // WebSocket proxy for deep research } // Protected proxy routes diff --git a/internal/auth/firebase_client.go b/internal/auth/firebase_client.go index 5dc1406..949acca 100644 --- a/internal/auth/firebase_client.go +++ b/internal/auth/firebase_client.go @@ -176,3 +176,127 @@ func (f *FirebaseClient) SaveDeepResearchCompletion(ctx context.Context, userID, return nil } + +// DeepResearchSessionState represents the state of a deep research session +type DeepResearchSessionState struct { + UserID string `firestore:"user_id"` + ChatID string `firestore:"chat_id"` + State string `firestore:"state"` // in_progress, clarify, error, complete + CreatedAt time.Time `firestore:"created_at"` + UpdatedAt time.Time `firestore:"updated_at"` + CompletedAt time.Time `firestore:"completed_at,omitempty"` +} + +// GetSessionState retrieves the current state of a deep research session +func (f *FirebaseClient) GetSessionState(ctx context.Context, userID, chatID string) (*DeepResearchSessionState, error) { + // Use underscore as separator since forward slash is not allowed in Firestore document IDs + sessionID := fmt.Sprintf("%s__%s", userID, chatID) + docRef := f.firestoreClient.Collection("deep_research_sessions").Doc(sessionID) + doc, err := docRef.Get(ctx) + + if err != nil { + // If document doesn't exist, session hasn't been created yet + if status.Code(err) == codes.NotFound { + return nil, nil + } + return nil, fmt.Errorf("failed to get session state: %w", err) + } + + var state DeepResearchSessionState + if err := doc.DataTo(&state); err != nil { + return nil, fmt.Errorf("failed to parse session state: %w", err) + } + + return &state, nil +} + +// UpdateSessionState updates the state of a deep research session +func (f *FirebaseClient) UpdateSessionState(ctx context.Context, userID, chatID, state string) error { + // Use underscore as separator since forward slash is not allowed in Firestore document IDs + sessionID := fmt.Sprintf("%s__%s", userID, chatID) + docRef := f.firestoreClient.Collection("deep_research_sessions").Doc(sessionID) + now := time.Now() + + // Check if document exists + _, err := docRef.Get(ctx) + + if err != nil { + // Document doesn't exist, create new one + if status.Code(err) == codes.NotFound { + sessionState := DeepResearchSessionState{ + UserID: userID, + ChatID: chatID, + State: state, + CreatedAt: now, + UpdatedAt: now, + } + + // Set completed_at if state is complete + if state == "complete" { + sessionState.CompletedAt = now + } + + _, err := docRef.Set(ctx, sessionState) + if err != nil { + return fmt.Errorf("failed to create session state: %w", err) + } + return nil + } + return fmt.Errorf("failed to get session state: %w", err) + } + + // Document exists, update it + updateData := map[string]interface{}{ + "state": state, + "updated_at": now, + } + + // Set completed_at when state is complete + if state == "complete" { + updateData["completed_at"] = now + } + + _, err = docRef.Set(ctx, updateData, firestore.MergeAll) + if err != nil { + return fmt.Errorf("failed to update session state: %w", err) + } + + return nil +} + +// GetActiveSessionsForUser retrieves all active (non-complete, non-error) sessions for a user +func (f *FirebaseClient) GetActiveSessionsForUser(ctx context.Context, userID string) ([]DeepResearchSessionState, error) { + query := f.firestoreClient.Collection("deep_research_sessions"). + Where("user_id", "==", userID). + Where("state", "in", []string{"in_progress", "clarify"}) + + docs, err := query.Documents(ctx).GetAll() + if err != nil { + return nil, fmt.Errorf("failed to get active sessions: %w", err) + } + + var sessions []DeepResearchSessionState + for _, doc := range docs { + var session DeepResearchSessionState + if err := doc.DataTo(&session); err != nil { + return nil, fmt.Errorf("failed to parse session: %w", err) + } + sessions = append(sessions, session) + } + + return sessions, nil +} + +// GetCompletedSessionCountForUser returns the number of completed deep research sessions for a user +func (f *FirebaseClient) GetCompletedSessionCountForUser(ctx context.Context, userID string) (int, error) { + query := f.firestoreClient.Collection("deep_research_sessions"). + Where("user_id", "==", userID). + Where("state", "==", "complete") + + docs, err := query.Documents(ctx).GetAll() + if err != nil { + return 0, fmt.Errorf("failed to get completed sessions count: %w", err) + } + + return len(docs), nil +} diff --git a/internal/deepr/db_storage.go b/internal/deepr/db_storage.go new file mode 100644 index 0000000..b5ea8a4 --- /dev/null +++ b/internal/deepr/db_storage.go @@ -0,0 +1,270 @@ +package deepr + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "time" + + "github.com/eternisai/enchanted-proxy/internal/logger" + "github.com/google/uuid" +) + +// DBStorage handles persistence of deep research messages to PostgreSQL +type DBStorage struct { + logger *logger.Logger + db *sql.DB +} + +// NewDBStorage creates a new database storage instance +func NewDBStorage(logger *logger.Logger, db *sql.DB) *DBStorage { + logger.WithComponent("deepr-db-storage").Info("database storage initialized") + + return &DBStorage{ + logger: logger, + db: db, + } +} + +// AddMessage adds a new message to the database +func (s *DBStorage) AddMessage(userID, chatID, message string, sent bool, messageType string) error { + log := s.logger.WithComponent("deepr-db-storage") + + messageID := uuid.New().String() + // Use double underscore as separator to match Firestore format + sessionID := fmt.Sprintf("%s__%s", userID, chatID) + now := time.Now().UTC() + + sentAt := sql.NullTime{} + if sent { + sentAt = sql.NullTime{ + Time: now, + Valid: true, + } + } + + query := ` + INSERT INTO deep_research_messages (id, user_id, chat_id, session_id, message, message_type, sent, created_at, sent_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ` + + _, err := s.db.Exec(query, messageID, userID, chatID, sessionID, message, messageType, sent, now, sentAt) + if err != nil { + log.Error("failed to add message to database", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_type", messageType), + slog.String("error", err.Error())) + return fmt.Errorf("failed to add message: %w", err) + } + + log.Debug("message added to database", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_id", messageID), + slog.String("message_type", messageType), + slog.Bool("sent", sent)) + + return nil +} + +// GetUnsentMessages retrieves all unsent messages for a session +func (s *DBStorage) GetUnsentMessages(userID, chatID string) ([]PersistedMessage, error) { + log := s.logger.WithComponent("deepr-db-storage") + + // Use double underscore as separator to match Firestore format + sessionID := fmt.Sprintf("%s__%s", userID, chatID) + + query := ` + SELECT id, user_id, chat_id, message, message_type, sent, created_at + FROM deep_research_messages + WHERE session_id = $1 AND sent = FALSE + ORDER BY created_at ASC + ` + + rows, err := s.db.Query(query, sessionID) + if err != nil { + log.Error("failed to query unsent messages", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + return nil, fmt.Errorf("failed to query unsent messages: %w", err) + } + defer rows.Close() + + var messages []PersistedMessage + for rows.Next() { + var msg PersistedMessage + err := rows.Scan(&msg.ID, &msg.UserID, &msg.ChatID, &msg.Message, &msg.MessageType, &msg.Sent, &msg.Timestamp) + if err != nil { + log.Error("failed to scan message row", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + return nil, fmt.Errorf("failed to scan message: %w", err) + } + messages = append(messages, msg) + } + + if err = rows.Err(); err != nil { + log.Error("error iterating message rows", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + return nil, fmt.Errorf("error iterating messages: %w", err) + } + + log.Info("retrieved unsent messages", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Int("unsent_count", len(messages))) + + return messages, nil +} + +// MarkMessageAsSent marks a specific message as sent +func (s *DBStorage) MarkMessageAsSent(userID, chatID, messageID string) error { + log := s.logger.WithComponent("deepr-db-storage") + + query := ` + UPDATE deep_research_messages + SET sent = TRUE, sent_at = NOW() + WHERE id = $1 + ` + + result, err := s.db.Exec(query, messageID) + if err != nil { + log.Error("failed to mark message as sent", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_id", messageID), + slog.String("error", err.Error())) + return fmt.Errorf("failed to mark message as sent: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + log.Warn("failed to get rows affected", + slog.String("message_id", messageID), + slog.String("error", err.Error())) + } else { + log.Debug("message marked as sent", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_id", messageID), + slog.Int64("rows_affected", rowsAffected)) + } + + return nil +} + +// MarkAllMessagesAsSent marks all unsent messages for a session as sent +func (s *DBStorage) MarkAllMessagesAsSent(userID, chatID string) error { + log := s.logger.WithComponent("deepr-db-storage") + + // Use double underscore as separator to match Firestore format + sessionID := fmt.Sprintf("%s__%s", userID, chatID) + + query := ` + UPDATE deep_research_messages + SET sent = TRUE, sent_at = NOW() + WHERE session_id = $1 AND sent = FALSE + ` + + result, err := s.db.Exec(query, sessionID) + if err != nil { + log.Error("failed to mark all messages as sent", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + return fmt.Errorf("failed to mark all messages as sent: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + log.Warn("failed to get rows affected", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + } else { + log.Info("all messages marked as sent", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Int64("rows_affected", rowsAffected)) + } + + return nil +} + +// UpdateBackendConnectionStatus is a no-op for database storage +// Connection status is tracked via session state in Firebase +func (s *DBStorage) UpdateBackendConnectionStatus(userID, chatID string, connected bool) error { + // No-op: backend connection status is tracked via Firebase session state + return nil +} + +// IsSessionComplete checks if a session has completed (has research_complete or error message) +func (s *DBStorage) IsSessionComplete(userID, chatID string) (bool, error) { + log := s.logger.WithComponent("deepr-db-storage") + + // Use double underscore as separator to match Firestore format + sessionID := fmt.Sprintf("%s__%s", userID, chatID) + + query := ` + SELECT COUNT(*) > 0 as is_complete + FROM deep_research_messages + WHERE session_id = $1 AND message_type IN ('research_complete', 'error') + LIMIT 1 + ` + + var isComplete bool + err := s.db.QueryRow(query, sessionID).Scan(&isComplete) + if err != nil { + log.Error("failed to check session completion", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + return false, fmt.Errorf("failed to check session completion: %w", err) + } + + log.Debug("session completion status checked", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Bool("is_complete", isComplete)) + + return isComplete, nil +} + +// CleanupOldSessions removes messages older than the specified duration +func (s *DBStorage) CleanupOldSessions(ctx context.Context, maxAge time.Duration) error { + log := s.logger.WithComponent("deepr-db-storage") + + cutoffTime := time.Now().Add(-maxAge) + + query := ` + DELETE FROM deep_research_messages + WHERE created_at < $1 + ` + + result, err := s.db.ExecContext(ctx, query, cutoffTime) + if err != nil { + log.Error("failed to cleanup old sessions", + slog.Duration("max_age", maxAge), + slog.Time("cutoff_time", cutoffTime), + slog.String("error", err.Error())) + return fmt.Errorf("failed to cleanup old sessions: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + log.Warn("failed to get rows affected for cleanup", + slog.String("error", err.Error())) + } else { + log.Info("old sessions cleaned up", + slog.Int64("messages_deleted", rowsAffected), + slog.Duration("max_age", maxAge)) + } + + return nil +} diff --git a/internal/deepr/handlers.go b/internal/deepr/handlers.go index cab04c2..bab67ec 100644 --- a/internal/deepr/handlers.go +++ b/internal/deepr/handlers.go @@ -18,7 +18,7 @@ var upgrader = websocket.Upgrader{ } // DeepResearchHandler handles WebSocket connections for deep research streaming -func DeepResearchHandler(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient) gin.HandlerFunc { +func DeepResearchHandler(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient, storage MessageStorage) gin.HandlerFunc { return func(c *gin.Context) { log := logger.WithContext(c.Request.Context()).WithComponent("deepr") @@ -74,8 +74,8 @@ func DeepResearchHandler(logger *logger.Logger, trackingService *request_trackin slog.String("chat_id", chatID), slog.String("remote_addr", c.Request.RemoteAddr)) - // Create service instance - service := NewService(logger, trackingService, firebaseClient) + // Create service instance with database storage + service := NewService(logger, trackingService, firebaseClient, storage) // Handle the WebSocket connection service.HandleConnection(c.Request.Context(), conn, userID, chatID) diff --git a/internal/deepr/models.go b/internal/deepr/models.go index 9e6d621..edd699b 100644 --- a/internal/deepr/models.go +++ b/internal/deepr/models.go @@ -24,7 +24,7 @@ type Response struct { Status string `json:"status,omitempty"` } -// PersistedMessage represents a message stored to disk +// PersistedMessage represents a message stored in the database type PersistedMessage struct { ID string `json:"id"` UserID string `json:"user_id"` @@ -34,14 +34,3 @@ type PersistedMessage struct { Timestamp time.Time `json:"timestamp"` MessageType string `json:"message_type"` // "status", "error", "final", etc. } - -// SessionState represents the state of a deep research session -type SessionState struct { - UserID string `json:"user_id"` - ChatID string `json:"chat_id"` - Messages []PersistedMessage `json:"messages"` - BackendConnected bool `json:"backend_connected"` - LastActivity time.Time `json:"last_activity"` - FinalReportReceived bool `json:"final_report_received"` - ErrorOccurred bool `json:"error_occurred"` -} diff --git a/internal/deepr/service.go b/internal/deepr/service.go index 3b83343..7e0778a 100644 --- a/internal/deepr/service.go +++ b/internal/deepr/service.go @@ -7,7 +7,6 @@ import ( "log/slog" "net/url" "os" - "path/filepath" "time" "github.com/eternisai/enchanted-proxy/internal/auth" @@ -22,24 +21,191 @@ type Service struct { logger *logger.Logger trackingService *request_tracking.Service firebaseClient *auth.FirebaseClient - storage *Storage + storage MessageStorage sessionManager *SessionManager } -// NewService creates a new deep research service -func NewService(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient) *Service { - // Get storage path from environment or use default - storagePath := os.Getenv("DEEPR_STORAGE_PATH") - if storagePath == "" { - storagePath = filepath.Join(".", "deepr_sessions") +// mapEventTypeToState maps event types from deep research server to session states +func mapEventTypeToState(eventType string) string { + switch eventType { + case "clarification_needed": + return "clarify" + case "error": + return "error" + case "research_complete": + return "complete" + default: + // All other events (research_progress, etc.) map to in_progress + return "in_progress" + } +} + +// canForwardMessage checks if a message from the client should be forwarded to the backend +// based on the current session state. Messages can only be forwarded when state is 'clarify' or 'error' +func (s *Service) canForwardMessage(ctx context.Context, userID, chatID string) (bool, string, error) { + log := s.logger.WithContext(ctx).WithComponent("deepr") + + sessionState, err := s.firebaseClient.GetSessionState(ctx, userID, chatID) + if err != nil { + log.Error("failed to get session state for message forwarding check", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + return false, "", fmt.Errorf("failed to get session state: %w", err) + } + + // If no session state exists yet, allow forwarding (initial message) + if sessionState == nil { + log.Debug("no session state found, allowing initial message", + slog.String("user_id", userID), + slog.String("chat_id", chatID)) + return true, "", nil + } + + // Only allow message forwarding when state is 'clarify' or 'error' + canForward := sessionState.State == "clarify" || sessionState.State == "error" + + log.Info("message forwarding check", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("session_state", sessionState.State), + slog.Bool("can_forward", canForward)) + + if !canForward { + return false, sessionState.State, nil + } + + return true, sessionState.State, nil +} + +// validateFreemiumAccess checks if a freemium user can start or continue a deep research session +// Returns error if user is not allowed to proceed +// Premium users can have multiple sessions but still cannot write during 'in_progress' state +func (s *Service) validateFreemiumAccess(ctx context.Context, clientConn *websocket.Conn, userID, chatID string, isReconnection bool) error { + log := s.logger.WithContext(ctx).WithComponent("deepr") + + // Check if user has active pro subscription + hasActivePro, _, err := s.trackingService.HasActivePro(ctx, userID) + if err != nil { + log.Error("failed to check subscription status", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify subscription status"}`)) + return fmt.Errorf("failed to check subscription status: %w", err) + } + + // Premium users can have multiple sessions - no restrictions on session creation + if hasActivePro { + log.Info("premium user, multiple sessions allowed", + slog.String("user_id", userID), + slog.String("chat_id", chatID)) + return nil + } + + // Freemium user - check restrictions + log.Info("freemium user detected, checking access restrictions", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Bool("is_reconnection", isReconnection)) + + // Get current session state + sessionState, err := s.firebaseClient.GetSessionState(ctx, userID, chatID) + if err != nil { + log.Error("failed to get session state", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify session state"}`)) + return fmt.Errorf("failed to get session state: %w", err) + } + + // If this is a reconnection or existing session + if sessionState != nil { + log.Info("existing session found", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("session_state", sessionState.State)) + + // Allow reconnection/continuation if state is 'clarify' or 'in_progress' + if sessionState.State == "clarify" || sessionState.State == "in_progress" { + log.Info("freemium user allowed to continue existing session", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("session_state", sessionState.State)) + return nil + } + + // If state is 'complete' or 'error', check if user has other completed sessions + if sessionState.State == "complete" || sessionState.State == "error" { + completedCount, err := s.firebaseClient.GetCompletedSessionCountForUser(ctx, userID) + if err != nil { + log.Error("failed to get completed session count", + slog.String("user_id", userID), + slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify usage status"}`)) + return fmt.Errorf("failed to get completed session count: %w", err) + } + + if completedCount >= 1 { + log.Warn("freemium quota exhausted - user has completed session", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Int("completed_count", completedCount)) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "You have already used your free deep research. Please upgrade to Pro for unlimited access.", "error_code": "FREE_LIMIT_REACHED"}`)) + return fmt.Errorf("freemium quota exhausted for user %s", userID) + } + } + + return nil + } + + // New session - check if user already has completed research + completedCount, err := s.firebaseClient.GetCompletedSessionCountForUser(ctx, userID) + if err != nil { + log.Error("failed to get completed session count", + slog.String("user_id", userID), + slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify usage status"}`)) + return fmt.Errorf("failed to get completed session count: %w", err) + } + + if completedCount >= 1 { + log.Warn("freemium quota exhausted - user already has completed research", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Int("completed_count", completedCount)) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "You have already used your free deep research. Please upgrade to Pro for unlimited access.", "error_code": "FREE_LIMIT_REACHED"}`)) + return fmt.Errorf("freemium quota exhausted for user %s", userID) } - storage, err := NewStorage(logger, storagePath) + // Check if user has any active (in_progress or clarify) sessions + activeSessions, err := s.firebaseClient.GetActiveSessionsForUser(ctx, userID) if err != nil { - logger.WithComponent("deepr").Error("failed to create storage, using in-memory only", + log.Error("failed to get active sessions", + slog.String("user_id", userID), slog.String("error", err.Error())) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify active sessions"}`)) + return fmt.Errorf("failed to get active sessions: %w", err) + } + + if len(activeSessions) > 0 { + log.Warn("freemium user already has an active session", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Int("active_sessions_count", len(activeSessions))) + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "You already have an active deep research session. Please complete or cancel it before starting a new one.", "error_code": "ACTIVE_SESSION_EXISTS"}`)) + return fmt.Errorf("freemium user %s already has active session", userID) } + log.Info("freemium user allowed to start new session", + slog.String("user_id", userID), + slog.String("chat_id", chatID)) + return nil +} + +// NewService creates a new deep research service with database storage +func NewService(logger *logger.Logger, trackingService *request_tracking.Service, firebaseClient *auth.FirebaseClient, storage MessageStorage) *Service { return &Service{ logger: logger, trackingService: trackingService, @@ -81,17 +247,15 @@ func (s *Service) HandleConnection(ctx context.Context, clientConn *websocket.Co slog.String("client_id", clientID), slog.Bool("is_reconnection", false)) - // New connection - perform subscription checks - // LIMIT CHECK DISABLED: Allow all users to use deep research - // if err := s.checkAndTrackSubscription(ctx, clientConn, userID); err != nil { - // log.Error("subscription check failed", - // slog.String("user_id", userID), - // slog.String("chat_id", chatID), - // slog.String("error", err.Error()), - // slog.Duration("duration", time.Since(startTime))) - // clientConn.Close() - // return - // } + // Validate freemium access for new connections + if err := s.validateFreemiumAccess(ctx, clientConn, userID, chatID, false); err != nil { + log.Error("freemium validation failed", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + clientConn.Close() + return + } // Create new backend connection s.handleNewConnection(ctx, clientConn, userID, chatID, clientID) @@ -192,6 +356,16 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. slog.String("chat_id", chatID), slog.String("client_id", clientID)) + // Validate freemium access for reconnections + if err := s.validateFreemiumAccess(ctx, clientConn, userID, chatID, true); err != nil { + log.Error("freemium validation failed for reconnection", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + clientConn.Close() + return + } + // Check if session is complete and replay unsent messages BEFORE adding client to session manager // This prevents concurrent writes: backend broadcast won't know about this client during replay if s.storage != nil { @@ -293,6 +467,29 @@ func (s *Service) handleReconnection(ctx context.Context, clientConn *websocket. slog.String("client_id", clientID), slog.Int("message_size", len(message))) + // Check if message can be forwarded based on session state + canForward, currentState, err := s.canForwardMessage(ctx, userID, chatID) + if err != nil { + log.Error("failed to check if message can be forwarded", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + // Send error to client + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify session state"}`)) + continue + } + + if !canForward { + log.Warn("message blocked - session state does not allow user input", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("session_state", currentState)) + // Send error to client indicating they cannot send messages in current state + errMsg := fmt.Sprintf(`{"error": "Cannot send messages while research is in progress", "session_state": "%s"}`, currentState) + clientConn.WriteMessage(websocket.TextMessage, []byte(errMsg)) + continue + } + // Forward to backend using synchronized write if err := s.sessionManager.WriteToBackend(userID, chatID, websocket.TextMessage, message); err != nil { log.Error("failed to forward message to backend", @@ -358,6 +555,29 @@ func (s *Service) handleClientMessages(ctx context.Context, clientConn *websocke slog.Int("message_size", len(message)), slog.Int("message_number", messageCount)) + // Check if message can be forwarded based on session state + canForward, currentState, err := s.canForwardMessage(ctx, userID, chatID) + if err != nil { + log.Error("failed to check if message can be forwarded", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("error", err.Error())) + // Send error to client + clientConn.WriteMessage(websocket.TextMessage, []byte(`{"error": "Failed to verify session state"}`)) + continue + } + + if !canForward { + log.Warn("message blocked - session state does not allow user input", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("session_state", currentState)) + // Send error to client indicating they cannot send messages in current state + errMsg := fmt.Sprintf(`{"error": "Cannot send messages while research is in progress", "session_state": "%s"}`, currentState) + clientConn.WriteMessage(websocket.TextMessage, []byte(errMsg)) + continue + } + // Forward to backend using synchronized write if err := s.sessionManager.WriteToBackend(userID, chatID, websocket.TextMessage, message); err != nil { log.Error("failed to forward message to backend", @@ -388,7 +608,7 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket deepResearchHost := os.Getenv("DEEP_RESEARCH_WS") if deepResearchHost == "" { - deepResearchHost = "165.232.133.47:3031" + deepResearchHost = "localhost:3031" log.Info("using default backend host", slog.String("host", deepResearchHost), slog.String("reason", "DEEP_RESEARCH_WS not set")) @@ -515,6 +735,23 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket } } + // Update session state in Firebase based on message type + sessionState := mapEventTypeToState(messageType) + if err := s.firebaseClient.UpdateSessionState(ctx, userID, chatID, sessionState); err != nil { + log.Error("failed to update session state in Firebase", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_type", messageType), + slog.String("session_state", sessionState), + slog.String("error", err.Error())) + } else { + log.Debug("session state updated in Firebase", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_type", messageType), + slog.String("session_state", sessionState)) + } + // Store message messageSent := false clientCount := s.sessionManager.GetClientCount(userID, chatID) @@ -524,6 +761,15 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket broadcastErr := s.sessionManager.BroadcastToClients(userID, chatID, message) messageSent = (broadcastErr == nil && clientCount > 0) + // Log detailed message info for debugging + log.Info("broadcasting message to clients", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_type", messageType), + slog.Bool("has_final_report", msg.FinalReport != ""), + slog.Int("client_count", clientCount), + slog.Bool("broadcast_success", broadcastErr == nil)) + // Store message with sent status if err := s.storage.AddMessage(userID, chatID, string(message), messageSent, messageType); err != nil { log.Error("failed to store message in storage", @@ -540,9 +786,9 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket slog.Int("client_count", clientCount)) } - // Track usage only when final_report is sent - if msg.FinalReport != "" { - log.Info("final report detected, tracking usage", + // Track usage only when research_complete event is sent + if msg.Type == "research_complete" { + log.Info("research complete event detected, tracking usage", slog.String("user_id", userID), slog.String("chat_id", chatID), slog.String("message_type", messageType)) @@ -633,6 +879,14 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket } else { // No storage, just broadcast broadcastErr := s.sessionManager.BroadcastToClients(userID, chatID, message) + + // Log detailed message info for debugging (no storage) + log.Info("broadcasting message to clients (no storage)", + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.String("message_type", messageType), + slog.Bool("has_final_report", msg.FinalReport != ""), + slog.Bool("broadcast_success", broadcastErr == nil)) if broadcastErr != nil { log.Warn("failed to broadcast message without storage", slog.String("user_id", userID), @@ -640,9 +894,9 @@ func (s *Service) handleNewConnection(ctx context.Context, clientConn *websocket slog.String("error", broadcastErr.Error())) } - // Track usage only when final_report is sent (even without storage) - if msg.FinalReport != "" { - log.Info("final report detected, tracking usage (no storage)", + // Track usage only when research_complete event is sent (even without storage) + if msg.Type == "research_complete" { + log.Info("research complete event detected, tracking usage (no storage)", slog.String("user_id", userID), slog.String("chat_id", chatID), slog.String("message_type", messageType)) diff --git a/internal/deepr/storage.go b/internal/deepr/storage.go index 54b8acb..41996b2 100644 --- a/internal/deepr/storage.go +++ b/internal/deepr/storage.go @@ -1,373 +1,12 @@ package deepr -import ( - "encoding/json" - "fmt" - "log/slog" - "os" - "path/filepath" - "sync" - "time" - - "github.com/eternisai/enchanted-proxy/internal/logger" - "github.com/google/uuid" -) - -// Storage handles persistence of deep research messages -type Storage struct { - logger *logger.Logger - storagePath string - mu sync.RWMutex -} - -// NewStorage creates a new storage instance -func NewStorage(logger *logger.Logger, storagePath string) (*Storage, error) { - // Create storage directory if it doesn't exist - if err := os.MkdirAll(storagePath, 0755); err != nil { - logger.WithComponent("deepr-storage").Error("failed to create storage directory", - slog.String("path", storagePath), - slog.String("error", err.Error())) - return nil, fmt.Errorf("failed to create storage directory: %w", err) - } - - logger.WithComponent("deepr-storage").Info("storage initialized", - slog.String("path", storagePath)) - - return &Storage{ - logger: logger, - storagePath: storagePath, - }, nil -} - -// getSessionFilePath returns the file path for a session -func (s *Storage) getSessionFilePath(userID, chatID string) string { - filename := fmt.Sprintf("session_%s_%s.json", userID, chatID) - return filepath.Join(s.storagePath, filename) -} - -// LoadSession loads a session state from disk -func (s *Storage) LoadSession(userID, chatID string) (*SessionState, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - filePath := s.getSessionFilePath(userID, chatID) - - data, err := os.ReadFile(filePath) - if err != nil { - if os.IsNotExist(err) { - // No existing session, return new one - return &SessionState{ - UserID: userID, - ChatID: chatID, - Messages: []PersistedMessage{}, - BackendConnected: false, - LastActivity: time.Now(), - FinalReportReceived: false, - ErrorOccurred: false, - }, nil - } - return nil, fmt.Errorf("failed to read session file: %w", err) - } - - var state SessionState - if err := json.Unmarshal(data, &state); err != nil { - return nil, fmt.Errorf("failed to unmarshal session: %w", err) - } - - return &state, nil -} - -// SaveSession saves a session state to disk -func (s *Storage) SaveSession(state *SessionState) error { - s.mu.Lock() - defer s.mu.Unlock() - - return s.saveSessionUnsafe(state) -} - -// loadSessionUnsafe loads a session without acquiring locks (internal use only) -func (s *Storage) loadSessionUnsafe(userID, chatID string) (*SessionState, error) { - filePath := s.getSessionFilePath(userID, chatID) - - data, err := os.ReadFile(filePath) - if err != nil { - if os.IsNotExist(err) { - // No existing session, return new one - return &SessionState{ - UserID: userID, - ChatID: chatID, - Messages: []PersistedMessage{}, - BackendConnected: false, - LastActivity: time.Now(), - FinalReportReceived: false, - ErrorOccurred: false, - }, nil - } - return nil, fmt.Errorf("failed to read session file: %w", err) - } - - var state SessionState - if err := json.Unmarshal(data, &state); err != nil { - return nil, fmt.Errorf("failed to unmarshal session: %w", err) - } - - return &state, nil -} - -// saveSessionUnsafe saves a session without acquiring locks (internal use only) -func (s *Storage) saveSessionUnsafe(state *SessionState) error { - filePath := s.getSessionFilePath(state.UserID, state.ChatID) - - state.LastActivity = time.Now() - - data, err := json.MarshalIndent(state, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal session: %w", err) - } - - if err := os.WriteFile(filePath, data, 0644); err != nil { - return fmt.Errorf("failed to write session file: %w", err) - } - - return nil -} - -// modifySession executes a mutation function on a session while holding the write lock -func (s *Storage) modifySession(userID, chatID string, mutate func(*SessionState) error) error { - s.mu.Lock() - defer s.mu.Unlock() - - state, err := s.loadSessionUnsafe(userID, chatID) - if err != nil { - return err - } - - if err := mutate(state); err != nil { - return err - } - - return s.saveSessionUnsafe(state) -} - -// AddMessage adds a new message to the session -func (s *Storage) AddMessage(userID, chatID, message string, sent bool, messageType string) error { - err := s.modifySession(userID, chatID, func(state *SessionState) error { - persistedMsg := PersistedMessage{ - ID: uuid.New().String(), - UserID: userID, - ChatID: chatID, - Message: message, - Sent: sent, - Timestamp: time.Now(), - MessageType: messageType, - } - - state.Messages = append(state.Messages, persistedMsg) - - // Check if this is a final report or error - var msg Message - if err := json.Unmarshal([]byte(message), &msg); err == nil { - if msg.FinalReport != "" { - state.FinalReportReceived = true - } - if msg.Type == "error" || msg.Error != "" { - state.ErrorOccurred = true - } - } - - s.logger.WithComponent("deepr-storage").Debug("message added to session", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.String("message_id", persistedMsg.ID), - slog.String("message_type", messageType), - slog.Bool("sent", sent), - slog.Int("total_messages", len(state.Messages))) - - return nil - }) - - if err != nil { - s.logger.WithComponent("deepr-storage").Error("failed to add message", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.String("message_type", messageType), - slog.String("error", err.Error())) - } - - return err -} - -// MarkMessageAsSent marks a specific message as sent -func (s *Storage) MarkMessageAsSent(userID, chatID, messageID string) error { - return s.modifySession(userID, chatID, func(state *SessionState) error { - for i := range state.Messages { - if state.Messages[i].ID == messageID { - state.Messages[i].Sent = true - break - } - } - return nil - }) -} - -// MarkAllMessagesAsSent marks all messages up to a certain index as sent -func (s *Storage) MarkAllMessagesAsSent(userID, chatID string) error { - return s.modifySession(userID, chatID, func(state *SessionState) error { - for i := range state.Messages { - state.Messages[i].Sent = true - } - return nil - }) -} - -// GetUnsentMessages returns all unsent messages for a session -func (s *Storage) GetUnsentMessages(userID, chatID string) ([]PersistedMessage, error) { - state, err := s.LoadSession(userID, chatID) - if err != nil { - s.logger.WithComponent("deepr-storage").Error("failed to load session for unsent messages", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.String("error", err.Error())) - return nil, err - } - - var unsent []PersistedMessage - for _, msg := range state.Messages { - if !msg.Sent { - unsent = append(unsent, msg) - } - } - - if len(unsent) > 0 { - s.logger.WithComponent("deepr-storage").Info("retrieved unsent messages", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.Int("unsent_count", len(unsent)), - slog.Int("total_messages", len(state.Messages))) - } - - return unsent, nil -} - -// GetLastUnsentMessage returns the last unsent message for a session -func (s *Storage) GetLastUnsentMessage(userID, chatID string) (*PersistedMessage, error) { - unsent, err := s.GetUnsentMessages(userID, chatID) - if err != nil { - return nil, err - } - - if len(unsent) == 0 { - return nil, nil - } - - return &unsent[len(unsent)-1], nil -} - -// UpdateBackendConnectionStatus updates the backend connection status -func (s *Storage) UpdateBackendConnectionStatus(userID, chatID string, connected bool) error { - err := s.modifySession(userID, chatID, func(state *SessionState) error { - state.BackendConnected = connected - return nil - }) - - if err == nil { - s.logger.WithComponent("deepr-storage").Info("backend connection status updated", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.Bool("connected", connected)) - } else { - s.logger.WithComponent("deepr-storage").Error("failed to update backend connection status", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.Bool("connected", connected), - slog.String("error", err.Error())) - } - - return err -} - -// IsSessionComplete checks if a session is complete (has final report or error) -func (s *Storage) IsSessionComplete(userID, chatID string) (bool, error) { - state, err := s.LoadSession(userID, chatID) - if err != nil { - s.logger.WithComponent("deepr-storage").Error("failed to load session for completion check", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.String("error", err.Error())) - return false, err - } - - isComplete := state.FinalReportReceived || state.ErrorOccurred - - s.logger.WithComponent("deepr-storage").Debug("session completion status checked", - slog.String("user_id", userID), - slog.String("chat_id", chatID), - slog.Bool("is_complete", isComplete), - slog.Bool("has_final_report", state.FinalReportReceived), - slog.Bool("has_error", state.ErrorOccurred)) - - return isComplete, nil -} - -// CleanupOldSessions removes session files older than the specified duration -func (s *Storage) CleanupOldSessions(maxAge time.Duration) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.logger.WithComponent("deepr-storage").Info("starting session cleanup", - slog.Duration("max_age", maxAge)) - - files, err := os.ReadDir(s.storagePath) - if err != nil { - s.logger.WithComponent("deepr-storage").Error("failed to read storage directory for cleanup", - slog.String("path", s.storagePath), - slog.String("error", err.Error())) - return fmt.Errorf("failed to read storage directory: %w", err) - } - - now := time.Now() - removedCount := 0 - errorCount := 0 - totalFiles := 0 - - for _, file := range files { - if file.IsDir() { - continue - } - - totalFiles++ - filePath := filepath.Join(s.storagePath, file.Name()) - info, err := file.Info() - if err != nil { - s.logger.WithComponent("deepr-storage").Error("failed to get file info during cleanup", - slog.String("file", file.Name()), - slog.String("error", err.Error())) - errorCount++ - continue - } - - fileAge := now.Sub(info.ModTime()) - if fileAge > maxAge { - if err := os.Remove(filePath); err != nil { - s.logger.WithComponent("deepr-storage").Error("failed to remove old session file", - slog.String("file", file.Name()), - slog.Duration("age", fileAge), - slog.String("error", err.Error())) - errorCount++ - } else { - s.logger.WithComponent("deepr-storage").Info("old session file removed", - slog.String("file", file.Name()), - slog.Duration("age", fileAge)) - removedCount++ - } - } - } - - s.logger.WithComponent("deepr-storage").Info("session cleanup completed", - slog.Int("total_files", totalFiles), - slog.Int("removed_count", removedCount), - slog.Int("error_count", errorCount), - slog.Duration("max_age", maxAge)) - - return nil +// MessageStorage defines the interface for storing deep research messages +// Implementations: DBStorage (database-backed, recommended) +type MessageStorage interface { + AddMessage(userID, chatID, message string, sent bool, messageType string) error + GetUnsentMessages(userID, chatID string) ([]PersistedMessage, error) + MarkMessageAsSent(userID, chatID, messageID string) error + MarkAllMessagesAsSent(userID, chatID string) error + UpdateBackendConnectionStatus(userID, chatID string, connected bool) error + IsSessionComplete(userID, chatID string) (bool, error) } diff --git a/internal/storage/pg/migrations/006_create_deep_research_messages.sql b/internal/storage/pg/migrations/006_create_deep_research_messages.sql new file mode 100644 index 0000000..63e407c --- /dev/null +++ b/internal/storage/pg/migrations/006_create_deep_research_messages.sql @@ -0,0 +1,25 @@ +-- +goose Up +CREATE TABLE IF NOT EXISTS deep_research_messages ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + chat_id TEXT NOT NULL, + session_id TEXT NOT NULL, -- user_id__chat_id combined (double underscore separator) + message TEXT NOT NULL, + message_type TEXT NOT NULL, -- status, error, research_complete, etc. + sent BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + sent_at TIMESTAMPTZ +); + +-- Indexes for efficient queries +CREATE INDEX IF NOT EXISTS idx_deep_research_messages_session_id ON deep_research_messages (session_id); +CREATE INDEX IF NOT EXISTS idx_deep_research_messages_user_chat ON deep_research_messages (user_id, chat_id); +CREATE INDEX IF NOT EXISTS idx_deep_research_messages_sent ON deep_research_messages (session_id, sent); +CREATE INDEX IF NOT EXISTS idx_deep_research_messages_created_at ON deep_research_messages (session_id, created_at); + +-- +goose Down +DROP INDEX IF EXISTS idx_deep_research_messages_created_at; +DROP INDEX IF EXISTS idx_deep_research_messages_sent; +DROP INDEX IF EXISTS idx_deep_research_messages_user_chat; +DROP INDEX IF EXISTS idx_deep_research_messages_session_id; +DROP TABLE IF EXISTS deep_research_messages; diff --git a/internal/storage/pg/queries/deep_research_messages.sql b/internal/storage/pg/queries/deep_research_messages.sql new file mode 100644 index 0000000..26856e2 --- /dev/null +++ b/internal/storage/pg/queries/deep_research_messages.sql @@ -0,0 +1,39 @@ +-- name: AddDeepResearchMessage :exec +INSERT INTO deep_research_messages (id, user_id, chat_id, session_id, message, message_type, sent, created_at, sent_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9); + +-- name: GetUnsentMessages :many +SELECT id, user_id, chat_id, session_id, message, message_type, sent, created_at, sent_at +FROM deep_research_messages +WHERE session_id = $1 AND sent = FALSE +ORDER BY created_at ASC; + +-- name: MarkMessageAsSent :exec +UPDATE deep_research_messages +SET sent = TRUE, sent_at = NOW() +WHERE id = $1; + +-- name: MarkAllMessagesAsSent :exec +UPDATE deep_research_messages +SET sent = TRUE, sent_at = NOW() +WHERE session_id = $1 AND sent = FALSE; + +-- name: GetSessionMessages :many +SELECT id, user_id, chat_id, session_id, message, message_type, sent, created_at, sent_at +FROM deep_research_messages +WHERE session_id = $1 +ORDER BY created_at ASC; + +-- name: DeleteSessionMessages :exec +DELETE FROM deep_research_messages +WHERE session_id = $1; + +-- name: GetSessionMessageCount :one +SELECT COUNT(*) as total_messages +FROM deep_research_messages +WHERE session_id = $1; + +-- name: GetUnsentMessageCount :one +SELECT COUNT(*) as unsent_count +FROM deep_research_messages +WHERE session_id = $1 AND sent = FALSE;