Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,7 @@ temp/
tests/

# Scripts
__pycache__/
__pycache__/

# Deep research session files (deprecated, now using database)
deepr_sessions/
7 changes: 6 additions & 1 deletion cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ func main() {
mcpService := mcp.NewService()
searchService := search.NewService(logger.WithComponent("search"))

// Initialize deep research storage
deeprStorage := deepr.NewDBStorage(logger.WithComponent("deepr-storage"), db.DB)

// Initialize handlers
oauthHandler := oauth.NewHandler(oauthService, logger.WithComponent("oauth"))
composioHandler := composio.NewHandler(composioService, logger.WithComponent("composio"))
Expand Down Expand Up @@ -190,6 +193,7 @@ func main() {
iapHandler: iapHandler,
mcpHandler: mcpHandler,
searchHandler: searchHandler,
deeprStorage: deeprStorage,
})

// Initialize GraphQL server for Telegram
Expand Down Expand Up @@ -290,6 +294,7 @@ type restServerInput struct {
iapHandler *iap.Handler
mcpHandler *mcp.Handler
searchHandler *search.Handler
deeprStorage deepr.MessageStorage
}

func setupRESTServer(input restServerInput) *gin.Engine {
Expand Down Expand Up @@ -364,7 +369,7 @@ func setupRESTServer(input restServerInput) *gin.Engine {
api.POST("/exa/search", input.searchHandler.PostExaSearchHandler) // POST /api/v1/exa/search (Exa AI)

// Deep Research WebSocket endpoint (protected)
api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient)) // WebSocket proxy for deep research
api.GET("/deepresearch/ws", deepr.DeepResearchHandler(input.logger, input.requestTrackingService, input.firebaseClient, input.deeprStorage)) // WebSocket proxy for deep research
}

// Protected proxy routes
Expand Down
124 changes: 124 additions & 0 deletions internal/auth/firebase_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,127 @@ func (f *FirebaseClient) SaveDeepResearchCompletion(ctx context.Context, userID,

return nil
}

// DeepResearchSessionState represents the state of a deep research session
type DeepResearchSessionState struct {
UserID string `firestore:"user_id"`
ChatID string `firestore:"chat_id"`
State string `firestore:"state"` // in_progress, clarify, error, complete
CreatedAt time.Time `firestore:"created_at"`
UpdatedAt time.Time `firestore:"updated_at"`
CompletedAt time.Time `firestore:"completed_at,omitempty"`
}

// GetSessionState retrieves the current state of a deep research session
func (f *FirebaseClient) GetSessionState(ctx context.Context, userID, chatID string) (*DeepResearchSessionState, error) {
// Use underscore as separator since forward slash is not allowed in Firestore document IDs
sessionID := fmt.Sprintf("%s__%s", userID, chatID)
docRef := f.firestoreClient.Collection("deep_research_sessions").Doc(sessionID)
doc, err := docRef.Get(ctx)

if err != nil {
// If document doesn't exist, session hasn't been created yet
if status.Code(err) == codes.NotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get session state: %w", err)
}

var state DeepResearchSessionState
if err := doc.DataTo(&state); err != nil {
return nil, fmt.Errorf("failed to parse session state: %w", err)
}

return &state, nil
}

// UpdateSessionState updates the state of a deep research session
func (f *FirebaseClient) UpdateSessionState(ctx context.Context, userID, chatID, state string) error {
// Use underscore as separator since forward slash is not allowed in Firestore document IDs
sessionID := fmt.Sprintf("%s__%s", userID, chatID)
docRef := f.firestoreClient.Collection("deep_research_sessions").Doc(sessionID)
now := time.Now()

// Check if document exists
_, err := docRef.Get(ctx)

if err != nil {
// Document doesn't exist, create new one
if status.Code(err) == codes.NotFound {
sessionState := DeepResearchSessionState{
UserID: userID,
ChatID: chatID,
State: state,
CreatedAt: now,
UpdatedAt: now,
}

// Set completed_at if state is complete
if state == "complete" {
sessionState.CompletedAt = now
}

_, err := docRef.Set(ctx, sessionState)
if err != nil {
return fmt.Errorf("failed to create session state: %w", err)
}
return nil
}
return fmt.Errorf("failed to get session state: %w", err)
}

// Document exists, update it
updateData := map[string]interface{}{
"state": state,
"updated_at": now,
}

// Set completed_at when state is complete
if state == "complete" {
updateData["completed_at"] = now
}

_, err = docRef.Set(ctx, updateData, firestore.MergeAll)
if err != nil {
return fmt.Errorf("failed to update session state: %w", err)
}

return nil
}

// GetActiveSessionsForUser retrieves all active (non-complete, non-error) sessions for a user
func (f *FirebaseClient) GetActiveSessionsForUser(ctx context.Context, userID string) ([]DeepResearchSessionState, error) {
query := f.firestoreClient.Collection("deep_research_sessions").
Where("user_id", "==", userID).
Where("state", "in", []string{"in_progress", "clarify"})

docs, err := query.Documents(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to get active sessions: %w", err)
}

var sessions []DeepResearchSessionState
for _, doc := range docs {
var session DeepResearchSessionState
if err := doc.DataTo(&session); err != nil {
return nil, fmt.Errorf("failed to parse session: %w", err)
}
sessions = append(sessions, session)
}

return sessions, nil
}

// GetCompletedSessionCountForUser returns the number of completed deep research sessions for a user
func (f *FirebaseClient) GetCompletedSessionCountForUser(ctx context.Context, userID string) (int, error) {
query := f.firestoreClient.Collection("deep_research_sessions").
Where("user_id", "==", userID).
Where("state", "==", "complete")

docs, err := query.Documents(ctx).GetAll()
if err != nil {
return 0, fmt.Errorf("failed to get completed sessions count: %w", err)
}

return len(docs), nil
}
Loading