diff --git a/.cursor/rules/testing.mdc b/.cursor/rules/testing.mdc index 5938ed34eb..797b51e1a6 100644 --- a/.cursor/rules/testing.mdc +++ b/.cursor/rules/testing.mdc @@ -36,6 +36,14 @@ describe("example", () => { }); ``` +### Helpers + +You can get mocks for emails, accounts, and rules here: + +```tsx +import { getEmail, getEmailAccount, getRule } from "@/__tests__/helpers"; +``` + ## Best Practices - Each test should be independent - Use descriptive test names diff --git a/apps/web/__tests__/ai-choose-rule.test.ts b/apps/web/__tests__/ai-choose-rule.test.ts index 0d6be0d06c..7b6e94d3b6 100644 --- a/apps/web/__tests__/ai-choose-rule.test.ts +++ b/apps/web/__tests__/ai-choose-rule.test.ts @@ -1,8 +1,8 @@ import { describe, expect, test, vi } from "vitest"; import { aiChooseRule } from "@/utils/ai/choose-rule/ai-choose-rule"; -import { type Action, ActionType, LogicalOperator } from "@prisma/client"; +import { ActionType } from "@prisma/client"; import { defaultReplyTrackerInstructions } from "@/utils/reply-tracker/consts"; -import { getEmail, getEmailAccount } from "@/__tests__/helpers"; +import { getEmail, getEmailAccount, getRule } from "@/__tests__/helpers"; // pnpm test-ai ai-choose-rule @@ -78,6 +78,7 @@ describe.runIf(isAiTest)("aiChooseRule", () => { url: null, folderName: null, delayInMinutes: null, + folderId: null, }, ]); @@ -327,26 +328,3 @@ describe.runIf(isAiTest)("aiChooseRule", () => { }); }); }); - -// helpers -function getRule(instructions: string, actions: Action[] = []) { - return { - instructions, - name: "Joke requests", - actions, - id: "id", - userId: "userId", - createdAt: new Date(), - updatedAt: new Date(), - automate: false, - runOnThreads: false, - groupId: null, - from: null, - subject: null, - body: null, - to: null, - enabled: true, - categoryFilterType: null, - conditionalOperator: LogicalOperator.AND, - }; -} diff --git a/apps/web/__tests__/helpers.ts b/apps/web/__tests__/helpers.ts index 5a11248d08..55caeedcff 100644 --- a/apps/web/__tests__/helpers.ts +++ b/apps/web/__tests__/helpers.ts @@ -1,5 +1,7 @@ import type { EmailAccountWithAI } from "@/utils/llms/types"; import type { EmailForLLM } from "@/utils/types"; +import { type Action, LogicalOperator } from "@prisma/client"; +import type { Prisma } from "@prisma/client"; export function getEmailAccount( overrides: Partial = {}, @@ -14,6 +16,9 @@ export function getEmailAccount( aiProvider: null, aiApiKey: null, }, + account: { + provider: "google", + }, }; } @@ -35,3 +40,96 @@ export function getEmail({ ...(cc && { cc }), }; } + +export function getRule(instructions: string, actions: Action[] = []) { + return { + instructions, + name: "Joke requests", + actions, + id: "id", + userId: "userId", + createdAt: new Date(), + updatedAt: new Date(), + automate: false, + runOnThreads: false, + groupId: null, + from: null, + subject: null, + body: null, + to: null, + enabled: true, + categoryFilterType: null, + conditionalOperator: LogicalOperator.AND, + }; +} + +export function getMockMessage({ + id = "msg1", + threadId = "thread1", + historyId = "12345", + from = "test@example.com", + to = "user@example.com", + subject = "Test", + snippet = "Test message", + textPlain = "Test content", + textHtml = "

Test content

", +}: { + id?: string; + threadId?: string; + historyId?: string; + from?: string; + to?: string; + subject?: string; + snippet?: string; + textPlain?: string; + textHtml?: string; +} = {}) { + return { + id, + threadId, + historyId, + headers: { + from, + to, + subject, + date: new Date().toISOString(), + }, + snippet, + textPlain, + textHtml, + attachments: [], + inline: [], + labelIds: [], + subject, + date: new Date().toISOString(), + }; +} + +export function getMockExecutedRule({ + messageId = "msg1", + threadId = "thread1", + ruleId = "rule1", + ruleName = "Test Rule", +}: { + messageId?: string; + threadId?: string; + ruleId?: string; + ruleName?: string; +} = {}): Prisma.ExecutedRuleGetPayload<{ + select: { + messageId: true; + threadId: true; + rule: { + select: { + id: true; + name: true; + }; + }; + }; +}> { + return { + messageId, + threadId, + rule: { id: ruleId, name: ruleName }, + }; +} diff --git a/apps/web/app/api/ai/analyze-sender-pattern/route.ts b/apps/web/app/api/ai/analyze-sender-pattern/route.ts index e22102d133..af41773fcb 100644 --- a/apps/web/app/api/ai/analyze-sender-pattern/route.ts +++ b/apps/web/app/api/ai/analyze-sender-pattern/route.ts @@ -1,25 +1,23 @@ import { NextResponse, after } from "next/server"; import { headers } from "next/headers"; -import type { gmail_v1 } from "@googleapis/gmail"; import { z } from "zod"; -import { getGmailClientWithRefresh } from "@/utils/gmail/client"; import { withError } from "@/utils/middleware"; import prisma from "@/utils/prisma"; -import { createScopedLogger } from "@/utils/logger"; +import { createScopedLogger, type Logger } from "@/utils/logger"; import { aiDetectRecurringPattern } from "@/utils/ai/choose-rule/ai-detect-recurring-pattern"; import { isValidInternalApiKey } from "@/utils/internal-api"; -import { getThreadMessages, getThreads } from "@/utils/gmail/thread"; import { extractEmailAddress } from "@/utils/email"; import { getEmailForLLM } from "@/utils/get-email-from-message"; import { saveLearnedPattern } from "@/utils/rule/learned-patterns"; +import { checkSenderRuleHistory } from "@/utils/rule/check-sender-rule-history"; +import { createEmailProvider } from "@/utils/email/provider"; +import type { EmailProvider } from "@/utils/email/types"; export const maxDuration = 60; -const THRESHOLD_EMAILS = 3; +const THRESHOLD_THREADS = 3; const MAX_RESULTS = 10; -const logger = createScopedLogger("api/ai/pattern-match"); - const schema = z.object({ emailAccountId: z.string(), from: z.string(), @@ -29,6 +27,8 @@ export type AnalyzeSenderPatternBody = z.infer; export const POST = withError(async (request) => { const json = await request.json(); + let logger = createScopedLogger("api/ai/pattern-match"); + if (!isValidInternalApiKey(await headers(), logger)) { logger.error("Invalid API key for sender pattern analysis", json); return NextResponse.json({ error: "Invalid API key" }); @@ -38,10 +38,12 @@ export const POST = withError(async (request) => { const { emailAccountId } = data; const from = extractEmailAddress(data.from); - logger.trace("Analyzing sender pattern", { emailAccountId, from }); + logger = logger.with({ emailAccountId, from }); + + logger.trace("Analyzing sender pattern"); // return immediately and process in background - after(() => process({ emailAccountId, from })); + after(() => process({ emailAccountId, from, logger })); return NextResponse.json({ processing: true }); }); @@ -56,20 +58,17 @@ export const POST = withError(async (request) => { async function process({ emailAccountId, from, + logger, }: { emailAccountId: string; from: string; + logger: Logger; }) { try { const emailAccount = await getEmailAccountWithRules({ emailAccountId }); - if (emailAccount?.account?.provider !== "google") { - logger.warn("Unsupported provider", { emailAccountId }); - return NextResponse.json({ success: false }, { status: 400 }); - } - if (!emailAccount) { - logger.error("Email account not found", { emailAccountId }); + logger.error("Email account not found"); return NextResponse.json({ success: false }, { status: 404 }); } @@ -83,55 +82,73 @@ async function process({ }); if (existingCheck?.patternAnalyzed) { - logger.info("Sender has already been analyzed", { from, emailAccountId }); + logger.info("Sender has already been analyzed"); return NextResponse.json({ success: true }); } const account = emailAccount.account; - if (!account?.access_token || !account?.refresh_token) { - logger.error("No Gmail account found", { emailAccountId }); + if (!account?.provider) { + logger.error("No email provider found"); return NextResponse.json({ success: false }, { status: 404 }); } - const gmail = await getGmailClientWithRefresh({ - accessToken: account.access_token, - refreshToken: account.refresh_token, - expiresAt: account.expires_at?.getTime() || null, + const provider = await createEmailProvider({ emailAccountId, + provider: account.provider, }); const threadsWithMessages = await getThreadsFromSender( - gmail, + provider, from, MAX_RESULTS, + logger, ); // If no threads found or we've detected a conversation, return early if (threadsWithMessages.length === 0) { - logger.info("No threads found from this sender", { - from, - emailAccountId, - }); + logger.info("No threads found from this sender"); // Don't record a check since we didn't run the AI analysis return NextResponse.json({ success: true }); } + if (threadsWithMessages.length < THRESHOLD_THREADS) { + logger.info("Not enough emails found from this sender", { + threadsWithMessagesCount: threadsWithMessages.length, + }); + + return NextResponse.json({ success: true }); + } + const allMessages = threadsWithMessages.flatMap( (thread) => thread.messages, ); - if (allMessages.length < THRESHOLD_EMAILS) { - logger.info("Not enough emails found from this sender", { - from, - emailAccountId, - count: allMessages.length, + const senderHistory = await checkSenderRuleHistory({ + emailAccountId, + from, + provider, + }); + + if (!senderHistory.hasConsistentRule) { + logger.info("Sender does not have consistent rule history", { + totalEmails: senderHistory.totalEmails, + uniqueRulesMatched: senderHistory.ruleMatches.size, }); + if (senderHistory.totalEmails > 0) { + await savePatternCheck({ emailAccountId, from }); + } + return NextResponse.json({ success: true }); } + logger.info("Sender has consistent rule history", { + consistentRule: senderHistory.consistentRuleName, + totalEmails: senderHistory.totalEmails, + }); + const emails = allMessages.map((message) => getEmailForLLM(message)); const patternResult = await aiDetectRecurringPattern({ @@ -141,25 +158,30 @@ async function process({ name: rule.name, instructions: rule.instructions || "", })), + consistentRuleName: senderHistory.consistentRuleName, }); if (patternResult?.matchedRule) { - await saveLearnedPattern({ - emailAccountId, - from, - ruleName: patternResult.matchedRule, - }); + // Verify the AI matched the same rule as the historical data + if (patternResult.matchedRule === senderHistory.consistentRuleName) { + await saveLearnedPattern({ + emailAccountId, + from, + ruleName: patternResult.matchedRule, + }); + } else { + logger.warn("AI suggested different rule than historical data", { + aiRule: patternResult.matchedRule, + historicalRule: senderHistory.consistentRuleName, + }); + } } await savePatternCheck({ emailAccountId, from }); return NextResponse.json({ success: true }); } catch (error) { - logger.error("Error in pattern match API", { - from, - emailAccountId, - error, - }); + logger.error("Error in pattern match API", { error }); return NextResponse.json( { error: "Failed to detect pattern" }, @@ -205,23 +227,23 @@ async function savePatternCheck({ * by excluding threads where users have replied or others have participated. */ async function getThreadsFromSender( - gmail: gmail_v1.Gmail, + provider: EmailProvider, sender: string, maxResults: number, + logger: Logger, ) { const from = extractEmailAddress(sender); - const threads = await getThreads( - `from:${from} -label:sent -label:draft`, - [], - gmail, + + const { threads } = await provider.getThreadsWithQuery({ + query: { fromEmail: from, type: "all" }, maxResults, - ); + }); const threadsWithMessages = []; // Check for conversation threads - for (const thread of threads.threads) { - const messages = await getThreadMessages(thread.id, gmail); + for (const thread of threads) { + const messages = await provider.getThreadMessages(thread.id); // Check if this is a conversation (multiple senders) const senders = messages.map((msg) => @@ -231,9 +253,7 @@ async function getThreadsFromSender( // If we found a conversation thread, skip this sender entirely if (hasOtherSenders) { - logger.info("Skipping sender pattern detection - conversation detected", { - from, - }); + logger.info("Skipping sender pattern detection - conversation detected"); return []; } diff --git a/apps/web/utils/__mocks__/email-provider.ts b/apps/web/utils/__mocks__/email-provider.ts new file mode 100644 index 0000000000..55da7031e8 --- /dev/null +++ b/apps/web/utils/__mocks__/email-provider.ts @@ -0,0 +1,134 @@ +import { vi } from "vitest"; +import type { EmailProvider } from "@/utils/email/types"; + +/** + * Creates a mock EmailProvider for testing + * + * Use this when: + * - You need a complete EmailProvider implementation + * - You're testing functions that interact with multiple EmailProvider methods + * - You want consistent default behavior across tests + * + * For simple tests that only use a few methods, consider creating a minimal mock: + * ```ts + * const mockProvider = { + * getMessage: vi.fn(), + * labelMessage: vi.fn(), + * } as unknown as EmailProvider; + * ``` + * + * @example + * ```ts + * // Basic usage + * const mockProvider = createMockEmailProvider(); + * + * // With overrides + * const mockProvider = createMockEmailProvider({ + * name: "microsoft", + * getMessage: vi.fn().mockResolvedValue(customMessage), + * }); + * + * // Setup specific behavior + * vi.mocked(mockProvider.getThreadMessages).mockResolvedValue([message1, message2]); + * ``` + */ +export const createMockEmailProvider = ( + overrides?: Partial, +): EmailProvider => ({ + name: "google", + getThreads: vi.fn().mockResolvedValue([]), + getThread: vi + .fn() + .mockResolvedValue({ + id: "thread1", + messages: [], + snippet: "Test thread snippet", + }), + getLabels: vi.fn().mockResolvedValue([]), + getLabelById: vi.fn().mockResolvedValue(null), + getMessage: vi.fn().mockResolvedValue({ + id: "msg1", + threadId: "thread1", + headers: { + from: "test@example.com", + to: "user@example.com", + subject: "Test", + date: new Date().toISOString(), + }, + snippet: "Test message", + historyId: "12345", + subject: "Test", + date: new Date().toISOString(), + textPlain: "Test content", + textHtml: "

Test content

", + attachments: [], + inline: [], + labelIds: [], + }), + getMessages: vi.fn().mockResolvedValue([]), + getSentMessages: vi.fn().mockResolvedValue([]), + getSentThreadsExcluding: vi.fn().mockResolvedValue([]), + getThreadMessages: vi.fn().mockResolvedValue([]), + getThreadMessagesInInbox: vi.fn().mockResolvedValue([]), + getPreviousConversationMessages: vi.fn().mockResolvedValue([]), + archiveThread: vi.fn().mockResolvedValue(undefined), + archiveThreadWithLabel: vi.fn().mockResolvedValue(undefined), + archiveMessage: vi.fn().mockResolvedValue(undefined), + trashThread: vi.fn().mockResolvedValue(undefined), + labelMessage: vi.fn().mockResolvedValue(undefined), + removeThreadLabel: vi.fn().mockResolvedValue(undefined), + getNeedsReplyLabel: vi.fn().mockResolvedValue(null), + getAwaitingReplyLabel: vi.fn().mockResolvedValue(null), + labelAwaitingReply: vi.fn().mockResolvedValue(undefined), + removeAwaitingReplyLabel: vi.fn().mockResolvedValue(undefined), + removeNeedsReplyLabel: vi.fn().mockResolvedValue(undefined), + draftEmail: vi.fn().mockResolvedValue({ draftId: "draft1" }), + replyToEmail: vi.fn().mockResolvedValue(undefined), + sendEmail: vi.fn().mockResolvedValue(undefined), + forwardEmail: vi.fn().mockResolvedValue(undefined), + markSpam: vi.fn().mockResolvedValue(undefined), + markRead: vi.fn().mockResolvedValue(undefined), + markReadThread: vi.fn().mockResolvedValue(undefined), + getDraft: vi.fn().mockResolvedValue(null), + deleteDraft: vi.fn().mockResolvedValue(undefined), + createLabel: vi + .fn() + .mockResolvedValue({ id: "label1", name: "Test Label", type: "user" }), + getOrCreateInboxZeroLabel: vi + .fn() + .mockResolvedValue({ id: "label1", name: "Test Label", type: "user" }), + getOriginalMessage: vi.fn().mockResolvedValue(null), + getFiltersList: vi.fn().mockResolvedValue([]), + createFilter: vi.fn().mockResolvedValue({}), + deleteFilter: vi.fn().mockResolvedValue({}), + createAutoArchiveFilter: vi.fn().mockResolvedValue({}), + getMessagesWithPagination: vi + .fn() + .mockResolvedValue({ messages: [], nextPageToken: undefined }), + getMessagesFromSender: vi + .fn() + .mockResolvedValue({ messages: [], nextPageToken: undefined }), + getMessagesBatch: vi.fn().mockResolvedValue([]), + getAccessToken: vi.fn().mockReturnValue("mock-token"), + checkIfReplySent: vi.fn().mockResolvedValue(false), + countReceivedMessages: vi.fn().mockResolvedValue(0), + getAttachment: vi.fn().mockResolvedValue({ data: "", size: 0 }), + getThreadsWithQuery: vi + .fn() + .mockResolvedValue({ threads: [], nextPageToken: undefined }), + hasPreviousCommunicationsWithSenderOrDomain: vi.fn().mockResolvedValue(false), + watchEmails: vi + .fn() + .mockResolvedValue({ expirationDate: new Date(), subscriptionId: "sub1" }), + unwatchEmails: vi.fn().mockResolvedValue(undefined), + isReplyInThread: vi.fn().mockReturnValue(false), + getThreadsFromSenderWithSubject: vi.fn().mockResolvedValue([]), + processHistory: vi.fn().mockResolvedValue(undefined), + moveThreadToFolder: vi.fn().mockResolvedValue(undefined), + ...overrides, +}); + +export const mockGmailProvider = createMockEmailProvider({ name: "google" }); +export const mockOutlookProvider = createMockEmailProvider({ + name: "microsoft", +}); diff --git a/apps/web/utils/ai/choose-rule/ai-detect-recurring-pattern.ts b/apps/web/utils/ai/choose-rule/ai-detect-recurring-pattern.ts index 25d8ce2c5c..ef3db3064e 100644 --- a/apps/web/utils/ai/choose-rule/ai-detect-recurring-pattern.ts +++ b/apps/web/utils/ai/choose-rule/ai-detect-recurring-pattern.ts @@ -20,6 +20,7 @@ export async function aiDetectRecurringPattern({ emails, emailAccount, rules, + consistentRuleName, }: { emails: EmailForLLM[]; emailAccount: EmailAccountWithAI; @@ -27,6 +28,7 @@ export async function aiDetectRecurringPattern({ name: string; instructions: string; }[]; + consistentRuleName?: string; }): Promise { // Extract the sender email from the first email // All emails should be from the same sender @@ -39,6 +41,8 @@ export async function aiDetectRecurringPattern({ Your task is to determine if emails from a specific sender should ALWAYS be matched to the same rule. +${consistentRuleName ? `IMPORTANT: Historical data shows that ALL previous emails from this sender have been matched to the "${consistentRuleName}" rule. Your task is to verify if this pattern should be learned for future emails.` : ""} + Analyze the email content to determine if this sender ALWAYS matches a specific rule. Only return a matchedRule if you're 90%+ confident all future emails from this sender will serve the same purpose; otherwise return null. @@ -46,13 +50,21 @@ A sender should only be matched to a rule if you are HIGHLY CONFIDENT that: - All future emails from this sender will serve the same purpose - The purpose clearly aligns with one specific rule - There's a consistent pattern across all sample emails provided +${consistentRuleName ? `- The content justifies always matching to the "${consistentRuleName}" rule` : ""} Examples of senders that typically match a single rule: - invoice@stripe.com → receipt rule (always sends payment confirmations) - newsletter@substack.com → newsletter rule (always sends newsletters) -- noreply@linkedin.com → social rule (always job or connection notifications) +- noreply@linkedin.com → notification rule (always sends platform notifications) +- calendar@calendly.com → calendar rule (always sends calendar invites) + +Examples of senders that should NOT have learned patterns: +- personal emails (john@gmail.com) → content varies too much -Pay close attention to the ACTUAL CONTENT of the sample emails provided. The decision should be based primarily on content analysis, not just the sender's email pattern. +Pay close attention to: +1. The sender's email domain - generic domains (gmail.com, outlook.com) rarely warrant pattern learning +2. The ACTUAL CONTENT of emails - must be consistently about the same topic/purpose +3. The sender's role - service-specific emails are good candidates, personal emails are not Be conservative in your matching. If there's any doubt, return null for "matchedRule". diff --git a/apps/web/utils/ai/choose-rule/run-rules.ts b/apps/web/utils/ai/choose-rule/run-rules.ts index 05240b5b6f..847d1c2bb1 100644 --- a/apps/web/utils/ai/choose-rule/run-rules.ts +++ b/apps/web/utils/ai/choose-rule/run-rules.ts @@ -313,20 +313,7 @@ async function analyzeSenderPatternIfAiMatch({ message: ParsedMessage; emailAccountId: string; }) { - if ( - !isTest && - result.rule && - // skip if we already matched for static reasons - // learnings only needed for rules that would run through an ai - !result.matchReasons?.some( - (reason) => - reason.type === "STATIC" || - reason.type === "GROUP" || - reason.type === "CATEGORY", - ) && - // skip if the match was "to reply" system rule - result?.rule?.systemType !== SystemType.TO_REPLY - ) { + if (shouldAnalyzeSenderPattern({ isTest, result })) { const fromAddress = extractEmailAddress(message.headers.from); if (fromAddress) { after(() => @@ -338,3 +325,28 @@ async function analyzeSenderPatternIfAiMatch({ } } } + +function shouldAnalyzeSenderPattern({ + isTest, + result, +}: { + isTest: boolean; + result: { rule?: Rule | null; matchReasons?: MatchReason[] }; +}) { + if (isTest) return false; + if (!result.rule) return false; + if (result.rule.systemType === SystemType.TO_REPLY) return false; + + // skip if we already matched for static reasons + // learnings only needed for rules that would run through an ai + if ( + result.matchReasons?.some( + (reason) => + reason.type === "STATIC" || + reason.type === "GROUP" || + reason.type === "CATEGORY", + ) + ) + return false; + return true; +} diff --git a/apps/web/utils/email/google.ts b/apps/web/utils/email/google.ts index 6abf102625..e333f41d43 100644 --- a/apps/web/utils/email/google.ts +++ b/apps/web/utils/email/google.ts @@ -498,6 +498,25 @@ export class GmailProvider implements EmailProvider { }; } + async getMessagesFromSender(options: { + senderEmail: string; + maxResults?: number; + pageToken?: string; + before?: Date; + after?: Date; + }): Promise<{ + messages: ParsedMessage[]; + nextPageToken?: string; + }> { + return this.getMessagesWithPagination({ + query: `from:${options.senderEmail}`, + maxResults: options.maxResults, + pageToken: options.pageToken, + before: options.before, + after: options.after, + }); + } + async getMessagesBatch(messageIds: string[]): Promise { return getMessagesBatch({ messageIds, diff --git a/apps/web/utils/email/microsoft.ts b/apps/web/utils/email/microsoft.ts index 7c42ef9cf3..cbbcd09042 100644 --- a/apps/web/utils/email/microsoft.ts +++ b/apps/web/utils/email/microsoft.ts @@ -57,6 +57,7 @@ import type { EmailFilter, } from "@/utils/email/types"; import { unwatchOutlook, watchOutlook } from "@/utils/outlook/watch"; +import { escapeODataString } from "@/utils/outlook/odata-escape"; const logger = createScopedLogger("outlook-provider"); @@ -174,7 +175,7 @@ export class OutlookProvider implements EmailProvider { // Add exclusion filters for TO emails for (const email of excludeToEmails) { - const escapedEmail = email.replace(/'/g, "''"); + const escapedEmail = escapeODataString(email); filters.push( `not (toRecipients/any(r: r/emailAddress/address eq '${escapedEmail}'))`, ); @@ -182,7 +183,7 @@ export class OutlookProvider implements EmailProvider { // Add exclusion filters for FROM emails for (const email of excludeFromEmails) { - const escapedEmail = email.replace(/'/g, "''"); + const escapedEmail = escapeODataString(email); filters.push(`not (from/emailAddress/address eq '${escapedEmail}')`); } @@ -534,6 +535,26 @@ export class OutlookProvider implements EmailProvider { }; } + async getMessagesFromSender(options: { + senderEmail: string; + maxResults?: number; + pageToken?: string; + before?: Date; + after?: Date; + }): Promise<{ + messages: ParsedMessage[]; + nextPageToken?: string; + }> { + const senderFilter = `from/emailAddress/address eq '${escapeODataString(options.senderEmail)}'`; + return this.getMessagesWithPagination({ + query: senderFilter, + maxResults: options.maxResults, + pageToken: options.pageToken, + before: options.before, + after: options.after, + }); + } + async getMessagesBatch(messageIds: string[]): Promise { // For Outlook, we need to fetch messages individually since there's no batch endpoint const messagePromises = messageIds.map((messageId) => @@ -660,7 +681,7 @@ export class OutlookProvider implements EmailProvider { // Add other filters if (query?.fromEmail) { // Escape single quotes in email address - const escapedEmail = query.fromEmail.replace(/'/g, "''"); + const escapedEmail = escapeODataString(query.fromEmail); filters.push(`from/emailAddress/address eq '${escapedEmail}'`); } @@ -811,7 +832,7 @@ export class OutlookProvider implements EmailProvider { .getClient() .api("/me/messages") .filter( - `from/emailAddress/address eq '${options.from}' and receivedDateTime lt ${options.date.toISOString()}`, + `from/emailAddress/address eq '${escapeODataString(options.from)}' and receivedDateTime lt ${options.date.toISOString()}`, ) .top(2) .select("id") diff --git a/apps/web/utils/email/types.ts b/apps/web/utils/email/types.ts index f3394a0598..bc837281b5 100644 --- a/apps/web/utils/email/types.ts +++ b/apps/web/utils/email/types.ts @@ -121,6 +121,16 @@ export interface EmailProvider { messages: ParsedMessage[]; nextPageToken?: string; }>; + getMessagesFromSender(options: { + senderEmail: string; + maxResults?: number; + pageToken?: string; + before?: Date; + after?: Date; + }): Promise<{ + messages: ParsedMessage[]; + nextPageToken?: string; + }>; getMessagesBatch(messageIds: string[]): Promise; getAccessToken(): string; checkIfReplySent(senderEmail: string): Promise; diff --git a/apps/web/utils/rule/check-sender-rule-history.test.ts b/apps/web/utils/rule/check-sender-rule-history.test.ts new file mode 100644 index 0000000000..dfbc775eb7 --- /dev/null +++ b/apps/web/utils/rule/check-sender-rule-history.test.ts @@ -0,0 +1,401 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { checkSenderRuleHistory } from "@/utils/rule/check-sender-rule-history"; +import prisma from "@/utils/__mocks__/prisma"; +import { createMockEmailProvider } from "@/utils/__mocks__/email-provider"; +import { getMockMessage, getMockExecutedRule } from "@/__tests__/helpers"; + +vi.mock("@/utils/prisma"); + +describe("checkSenderRuleHistory", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockProvider = createMockEmailProvider(); + + it("should return no consistent rule when no messages found from sender", async () => { + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: [], + nextPageToken: undefined, + }); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(0); + expect(result.hasConsistentRule).toBe(false); + expect(result.consistentRuleName).toBeUndefined(); + expect(mockProvider.getMessagesFromSender).toHaveBeenCalledWith({ + senderEmail: "test@example.com", + maxResults: 50, + }); + }); + + it("should return consistent rule when all emails match the same rule", async () => { + const mockMessages = [ + getMockMessage({ + id: "msg1", + threadId: "thread1", + subject: "Test 1", + snippet: "Test message 1", + textPlain: "Test content 1", + textHtml: "

Test content 1

", + }), + getMockMessage({ + id: "msg2", + threadId: "thread2", + subject: "Test 2", + snippet: "Test message 2", + textPlain: "Test content 2", + textHtml: "

Test content 2

", + }), + getMockMessage({ + id: "msg3", + threadId: "thread3", + subject: "Test 3", + snippet: "Test message 3", + textPlain: "Test content 3", + textHtml: "

Test content 3

", + }), + ]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + const mockExecutedRules = [ + getMockExecutedRule({ + messageId: "msg1", + threadId: "thread1", + ruleId: "rule1", + ruleName: "Newsletter", + }), + getMockExecutedRule({ + messageId: "msg2", + threadId: "thread2", + ruleId: "rule1", + ruleName: "Newsletter", + }), + getMockExecutedRule({ + messageId: "msg3", + threadId: "thread3", + ruleId: "rule1", + ruleName: "Newsletter", + }), + ]; + + prisma.executedRule.findMany.mockResolvedValue(mockExecutedRules as any); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(3); + expect(result.hasConsistentRule).toBe(true); + expect(result.consistentRuleName).toBe("Newsletter"); + expect(result.ruleMatches.size).toBe(1); + + // Verify database query was called with correct message IDs + expect(prisma.executedRule.findMany).toHaveBeenCalledWith({ + where: { + emailAccountId: "test-email-account", + status: "APPLIED", + messageId: { in: ["msg1", "msg2", "msg3"] }, + rule: { + enabled: true, + }, + }, + select: { + messageId: true, + threadId: true, + rule: { + select: { + id: true, + name: true, + }, + }, + }, + }); + }); + + it("should return no consistent rule when emails match different rules", async () => { + const mockMessages = [ + getMockMessage({ + id: "msg1", + threadId: "thread1", + subject: "Test 1", + snippet: "Test message 1", + }), + getMockMessage({ + id: "msg2", + threadId: "thread2", + subject: "Test 2", + snippet: "Test message 2", + }), + getMockMessage({ + id: "msg3", + threadId: "thread3", + subject: "Test 3", + snippet: "Test message 3", + }), + ]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + const mockExecutedRules = [ + getMockExecutedRule({ + messageId: "msg1", + threadId: "thread1", + ruleId: "rule1", + ruleName: "Newsletter", + }), + getMockExecutedRule({ + messageId: "msg2", + threadId: "thread2", + ruleId: "rule2", + ruleName: "Calendar", + }), + getMockExecutedRule({ + messageId: "msg3", + threadId: "thread3", + ruleId: "rule1", + ruleName: "Newsletter", + }), + ]; + + prisma.executedRule.findMany.mockResolvedValue(mockExecutedRules as any); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(3); + expect(result.hasConsistentRule).toBe(false); + expect(result.consistentRuleName).toBeUndefined(); + expect(result.ruleMatches.size).toBe(2); + + // Verify both rules are counted + const newsletterRule = result.ruleMatches.get("rule1"); + const calendarRule = result.ruleMatches.get("rule2"); + expect(newsletterRule?.count).toBe(2); + expect(calendarRule?.count).toBe(1); + }); + + it("should handle messages with no executed rules", async () => { + const mockMessages = [ + getMockMessage({ + id: "msg1", + threadId: "thread1", + subject: "Test 1", + snippet: "Test message 1", + }), + getMockMessage({ + id: "msg2", + threadId: "thread2", + subject: "Test 2", + snippet: "Test message 2", + }), + ]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + // No executed rules found for these messages + prisma.executedRule.findMany.mockResolvedValue([]); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(2); // 2 messages from sender + expect(result.hasConsistentRule).toBe(false); // No rules applied + expect(result.consistentRuleName).toBeUndefined(); + expect(result.ruleMatches.size).toBe(0); + }); + + it("should handle getMessagesFromSender errors gracefully", async () => { + // Mock getMessagesFromSender to throw an error + vi.mocked(mockProvider.getMessagesFromSender).mockRejectedValue( + new Error("Failed to fetch messages from provider"), + ); + + await expect( + checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }), + ).rejects.toThrow("Failed to fetch messages from provider"); + }); + + it("should handle database query errors gracefully", async () => { + const mockMessages = [getMockMessage({ id: "msg1", threadId: "thread1" })]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + // Mock database error + prisma.executedRule.findMany.mockRejectedValue( + new Error("Database connection failed"), + ); + + await expect( + checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }), + ).rejects.toThrow("Database connection failed"); + }); + + it("should extract email address from complex from field", async () => { + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: [], + nextPageToken: undefined, + }); + + await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "John Doe ", // Complex from field + provider: mockProvider, + }); + + expect(mockProvider.getMessagesFromSender).toHaveBeenCalledWith({ + senderEmail: "john@example.com", // Should extract just the email + maxResults: 50, + }); + }); + + it("should handle executed rules without associated rule (deleted rules)", async () => { + const mockMessages = [ + getMockMessage({ id: "msg1", threadId: "thread1" }), + getMockMessage({ id: "msg2", threadId: "thread2" }), + ]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + const mockExecutedRules = [ + getMockExecutedRule({ + messageId: "msg1", + threadId: "thread1", + ruleId: "rule1", + ruleName: "Newsletter", + }), + // Skip msg2 - simulates no executed rule found (rule was deleted) + ]; + + prisma.executedRule.findMany.mockResolvedValue(mockExecutedRules as any); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(2); + expect(result.ruleMatches.size).toBe(1); // Only one executed rule found + expect(result.hasConsistentRule).toBe(true); // Only one rule type exists + }); + + it("should handle duplicate message IDs correctly", async () => { + const mockMessages = [ + getMockMessage({ id: "msg1", threadId: "thread1" }), + getMockMessage({ id: "msg2", threadId: "thread2" }), + getMockMessage({ id: "msg3", threadId: "thread3" }), + ]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + const mockExecutedRules = [ + getMockExecutedRule({ + messageId: "msg1", + threadId: "thread1", + ruleId: "rule1", + ruleName: "Newsletter", + }), + getMockExecutedRule({ + messageId: "msg1", + threadId: "thread1", + ruleId: "rule1", + ruleName: "Newsletter", + }), // Duplicate + getMockExecutedRule({ + messageId: "msg2", + threadId: "thread2", + ruleId: "rule1", + ruleName: "Newsletter", + }), + ]; + + prisma.executedRule.findMany.mockResolvedValue(mockExecutedRules as any); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(3); + expect(result.ruleMatches.size).toBe(1); + const newsletterRule = result.ruleMatches.get("rule1"); + expect(newsletterRule?.count).toBe(2); // Should not double-count msg1 + }); + + it("should handle partial rule coverage", async () => { + const mockMessages = [ + getMockMessage({ id: "msg1", threadId: "thread1" }), + getMockMessage({ id: "msg2", threadId: "thread2" }), + ]; + + vi.mocked(mockProvider.getMessagesFromSender).mockResolvedValue({ + messages: mockMessages, + nextPageToken: undefined, + }); + + // Only one message has an executed rule + const mockExecutedRules = [ + getMockExecutedRule({ + messageId: "msg1", + threadId: "thread1", + ruleId: "rule1", + ruleName: "Newsletter", + }), + ]; + + prisma.executedRule.findMany.mockResolvedValue(mockExecutedRules as any); + + const result = await checkSenderRuleHistory({ + emailAccountId: "test-email-account", + from: "test@example.com", + provider: mockProvider, + }); + + expect(result.totalEmails).toBe(2); + expect(result.ruleMatches.size).toBe(1); + expect(result.hasConsistentRule).toBe(true); // Single rule type + expect(result.consistentRuleName).toBe("Newsletter"); + }); +}); diff --git a/apps/web/utils/rule/check-sender-rule-history.ts b/apps/web/utils/rule/check-sender-rule-history.ts new file mode 100644 index 0000000000..70e5df69ac --- /dev/null +++ b/apps/web/utils/rule/check-sender-rule-history.ts @@ -0,0 +1,126 @@ +import sumBy from "lodash/sumBy"; +import prisma from "@/utils/prisma"; +import { createScopedLogger } from "@/utils/logger"; +import type { EmailProvider } from "@/utils/email/types"; +import { extractEmailAddress } from "@/utils/email"; +import { ExecutedRuleStatus } from "@prisma/client"; + +export interface SenderRuleHistory { + totalEmails: number; + ruleMatches: Map; + hasConsistentRule: boolean; + consistentRuleName?: string; +} + +/** + * Checks the historical rule matches for a specific sender + * Returns information about which rules have been applied to this sender's emails + */ +export async function checkSenderRuleHistory({ + emailAccountId, + from, + provider, +}: { + emailAccountId: string; + from: string; + provider: EmailProvider; +}): Promise { + const logger = createScopedLogger("checkSenderRuleHistory").with({ + emailAccountId, + from, + }); + const senderEmail = extractEmailAddress(from); + + logger.info("Checking sender rule history"); + + const { messages } = await provider.getMessagesFromSender({ + senderEmail, + maxResults: 50, + }); + + logger.info("Found messages from sender", { totalMessages: messages.length }); + + if (messages.length === 0) { + return { + totalEmails: 0, + ruleMatches: new Map(), + hasConsistentRule: false, + }; + } + + const messageIds = messages.map((message) => message.id); + + const executedRules = await prisma.executedRule.findMany({ + where: { + emailAccountId, + status: ExecutedRuleStatus.APPLIED, + messageId: { in: messageIds }, + rule: { enabled: true }, + }, + select: { + messageId: true, + threadId: true, + rule: { select: { id: true, name: true } }, + }, + }); + + logger.info("Found executed rules for sender messages", { + totalExecutedRules: executedRules.length, + }); + + // Process the results + const ruleMatches = new Map(); + const processedMessageIds = new Set(); + + for (const executedRule of executedRules) { + if (!executedRule.rule) continue; + + // Avoid double-counting if we match both messageId and threadId for the same message + const messageKey = executedRule.messageId || executedRule.threadId; + if (!messageKey || processedMessageIds.has(messageKey)) continue; + + processedMessageIds.add(messageKey); + + const existing = ruleMatches.get(executedRule.rule.id); + if (existing) { + existing.count++; + } else { + ruleMatches.set(executedRule.rule.id, { + ruleName: executedRule.rule.name, + count: 1, + }); + } + } + + const totalEmailsFromSender = messages.length; + const totalRuleMatches = sumBy( + Array.from(ruleMatches.values()), + (rule) => rule.count, + ); + + // Check if there's a consistent rule + let hasConsistentRule = false; + let consistentRuleName: string | undefined; + + if (totalRuleMatches > 0 && ruleMatches.size === 1) { + // All rule executions were for the same rule + const [[, ruleInfo]] = Array.from(ruleMatches.entries()); + hasConsistentRule = true; + consistentRuleName = ruleInfo.ruleName; + } + + logger.info("Sender rule history analysis complete", { + totalEmailsFromSender, + totalRuleMatches, + uniqueRulesMatched: ruleMatches.size, + hasConsistentRule, + consistentRuleName, + }); + + return { + totalEmails: totalEmailsFromSender, + ruleMatches, + hasConsistentRule, + consistentRuleName, + }; +} diff --git a/version.txt b/version.txt index 1e8349cd6d..2879391475 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -v2.10.2 +v2.10.3