diff --git a/.vscode/settings.json b/.vscode/settings.json index 5506a35787..287fb3e19a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -15,5 +15,6 @@ "emmet.showExpandedAbbreviation": "never", "[prisma]": { "editor.defaultFormatter": "Prisma.prisma" - } + }, + "prisma.pinToPrisma6": true } \ No newline at end of file diff --git a/apps/web/app/(app)/[emailAccountId]/assistant/RuleForm.tsx b/apps/web/app/(app)/[emailAccountId]/assistant/RuleForm.tsx index 33b1f29971..2e32b7c60e 100644 --- a/apps/web/app/(app)/[emailAccountId]/assistant/RuleForm.tsx +++ b/apps/web/app/(app)/[emailAccountId]/assistant/RuleForm.tsx @@ -518,7 +518,7 @@ export function RuleForm({ )} diff --git a/apps/web/app/(app)/[emailAccountId]/assistant/group/LearnedPatterns.tsx b/apps/web/app/(app)/[emailAccountId]/assistant/group/LearnedPatterns.tsx index ce78cb4b7e..8996eb212d 100644 --- a/apps/web/app/(app)/[emailAccountId]/assistant/group/LearnedPatterns.tsx +++ b/apps/web/app/(app)/[emailAccountId]/assistant/group/LearnedPatterns.tsx @@ -73,7 +73,7 @@ export function LearnedPatternsDialog({ - + Learned Patterns diff --git a/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx b/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx index fad0132c15..bcd41bf243 100644 --- a/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx +++ b/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx @@ -96,21 +96,6 @@ function ViewGroupInner({ groupId }: { groupId: string }) { Add pattern - - {!!group?.items?.length && ( - - )} )} diff --git a/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx b/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx index 32e66884e3..51bd127434 100644 --- a/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx +++ b/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx @@ -115,8 +115,11 @@ function WritingStyleDialog({ registerProps={register("writingStyle")} error={errors.writingStyle} placeholder="Typical Length: 2-3 sentences + Formality: Informal but professional + Common Greeting: Hey, + Notable Traits: - Uses contractions frequently - Concise and direct responses diff --git a/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx b/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx index 8a08080c30..3ccae88683 100644 --- a/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx +++ b/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx @@ -20,6 +20,7 @@ import { AlertBasic } from "@/components/Alert"; import { Button } from "@/components/ui/button"; import { useSearchParams } from "next/navigation"; import { markNotColdEmailAction } from "@/utils/actions/cold-email"; +import { toggleRuleAction } from "@/utils/actions/rule"; import { Checkbox } from "@/components/Checkbox"; import { useToggleSelect } from "@/hooks/useToggleSelect"; import { ViewEmailButton } from "@/components/ViewEmailButton"; @@ -27,9 +28,9 @@ import { EmailMessageCellWithData } from "@/components/EmailMessageCell"; import { EnableFeatureCard } from "@/components/EnableFeatureCard"; import { toastError, toastSuccess } from "@/components/Toast"; import { useAccount } from "@/providers/EmailAccountProvider"; -import { prefixPath } from "@/utils/path"; import { useRules } from "@/hooks/useRules"; import { isColdEmailBlockerEnabled } from "@/utils/cold-email/cold-email-blocker-enabled"; +import { SystemType } from "@/generated/prisma/enums"; export function ColdEmailList() { const searchParams = useSearchParams(); @@ -187,18 +188,36 @@ function Row({ function NoColdEmails() { const { emailAccountId } = useAccount(); - const { data: rules } = useRules(); + const { data: rules, mutate: mutateRules } = useRules(); + + const { executeAsync: enableColdEmailBlocker } = useAction( + toggleRuleAction.bind(null, emailAccountId), + { + onSuccess: () => { + toastSuccess({ description: "Cold email blocker enabled!" }); + mutateRules(); + }, + onError: () => { + toastError({ description: "Error enabling cold email blocker" }); + }, + }, + ); if (!isColdEmailBlockerEnabled(rules || [])) { return (
{ + await enableColdEmailBlocker({ + systemType: SystemType.COLD_EMAIL, + enabled: true, + }); + }} hideBorder />
diff --git a/apps/web/app/api/ai/analyze-sender-pattern/route.ts b/apps/web/app/api/ai/analyze-sender-pattern/route.ts index 23d53d3134..dc9bfb6082 100644 --- a/apps/web/app/api/ai/analyze-sender-pattern/route.ts +++ b/apps/web/app/api/ai/analyze-sender-pattern/route.ts @@ -10,6 +10,7 @@ import { isValidInternalApiKey } from "@/utils/internal-api"; import { extractEmailAddress } from "@/utils/email"; import { getEmailForLLM } from "@/utils/get-email-from-message"; import { saveLearnedPattern } from "@/utils/rule/learned-patterns"; +import { GroupItemSource } from "@/generated/prisma/enums"; import { checkSenderRuleHistory } from "@/utils/rule/check-sender-rule-history"; import { createEmailProvider } from "@/utils/email/provider"; import type { EmailProvider } from "@/utils/email/types"; @@ -177,12 +178,24 @@ async function process({ if (patternResult?.matchedRule) { // Verify the AI matched the same rule as the historical data if (patternResult.matchedRule === senderHistory.consistentRuleName) { - await saveLearnedPattern({ - emailAccountId, - from, - ruleName: patternResult.matchedRule, - logger, - }); + const matchedRule = emailAccount.rules.find( + (rule) => rule.name === patternResult.matchedRule, + ); + + if (matchedRule) { + await saveLearnedPattern({ + emailAccountId, + from, + ruleId: matchedRule.id, + logger, + source: GroupItemSource.AI, + }); + } else { + logger.error("Matched rule not found in email account rules", { + ruleName: patternResult.matchedRule, + availableRules: emailAccount.rules.map((r) => r.name), + }); + } } else { logger.warn("AI suggested different rule than historical data", { aiRule: patternResult.matchedRule, diff --git a/apps/web/app/api/google/webhook/process-label-removed-event.test.ts b/apps/web/app/api/google/webhook/process-label-removed-event.test.ts index 39c895222e..10c95dbef2 100644 --- a/apps/web/app/api/google/webhook/process-label-removed-event.test.ts +++ b/apps/web/app/api/google/webhook/process-label-removed-event.test.ts @@ -1,20 +1,27 @@ import { vi, describe, it, expect, beforeEach } from "vitest"; -import { ColdEmailStatus } from "@/generated/prisma/enums"; import { HistoryEventType } from "./types"; import { handleLabelRemovedEvent } from "./process-label-removed-event"; import type { gmail_v1 } from "@googleapis/gmail"; -import { saveLearnedPatterns } from "@/utils/rule/learned-patterns"; -import prisma from "@/utils/__mocks__/prisma"; +import { saveLearnedPattern } from "@/utils/rule/learned-patterns"; import { createScopedLogger } from "@/utils/logger"; +import { GroupItemSource, SystemType } from "@/generated/prisma/enums"; +import prisma from "@/utils/prisma"; const logger = createScopedLogger("test"); vi.mock("server-only", () => ({})); // Mock dependencies -vi.mock("@/utils/prisma"); +vi.mock("@/utils/prisma", () => ({ + default: { + rule: { + findFirst: vi.fn(), + }, + }, +})); + vi.mock("@/utils/rule/learned-patterns", () => ({ - saveLearnedPatterns: vi.fn().mockResolvedValue(undefined), + saveLearnedPattern: vi.fn().mockResolvedValue(undefined), })); vi.mock("@/utils/gmail/label", () => ({ @@ -89,88 +96,42 @@ describe("process-label-removed-event", () => { }; describe("handleLabelRemovedEvent", () => { - it("should process Cold Email label removal and update ColdEmail status", async () => { - prisma.coldEmail.upsert.mockResolvedValue({} as any); + it("should process Cold Email label removal and call saveLearnedPattern with exclude: true", async () => { + vi.mocked(prisma.rule.findFirst).mockResolvedValue({ + id: "rule-123", + systemType: SystemType.COLD_EMAIL, + } as any); const historyItem = createLabelRemovedHistoryItem(); - console.log("Test data:", JSON.stringify(historyItem.item, null, 2)); - - try { - await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); - } catch (error) { - console.error("Function error:", error); - throw error; - } - - expect(prisma.coldEmail.upsert).toHaveBeenCalledWith({ - where: { - emailAccountId_fromEmail: { - emailAccountId: "email-account-id", - fromEmail: "sender@example.com", - }, - }, - update: { - status: ColdEmailStatus.USER_REJECTED_COLD, - }, - create: { - status: ColdEmailStatus.USER_REJECTED_COLD, - fromEmail: "sender@example.com", - emailAccountId: "email-account-id", - messageId: "123", - threadId: "thread-123", - }, - }); - }); - - it("should skip learning when Newsletter label is removed (only Cold Email is supported)", async () => { - const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ - "label-2", - ]); - await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); - expect(saveLearnedPatterns).not.toHaveBeenCalled(); - }); - - it("should skip learning when To Reply label is removed (only Cold Email is supported)", async () => { - const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ - "label-4", - ]); - - await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); - - expect(saveLearnedPatterns).not.toHaveBeenCalled(); - }); - - it("should skip learning when no executed rule exists (only Cold Email is supported)", async () => { - const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ - "label-2", - ]); - - await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); - - expect(saveLearnedPatterns).not.toHaveBeenCalled(); + expect(saveLearnedPattern).toHaveBeenCalledWith({ + emailAccountId: "email-account-id", + from: "sender@example.com", + ruleId: "rule-123", + exclude: true, + logger: expect.anything(), + messageId: "123", + threadId: "thread-123", + reason: "Label removed", + source: GroupItemSource.LABEL_REMOVED, + }); }); - it("should skip learning when no matching LABEL action is found (only Cold Email is supported)", async () => { - const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ - "label-2", - ]); + it("should skip learning when To Reply label is removed (not a learnable rule)", async () => { + vi.mocked(prisma.rule.findFirst).mockResolvedValue({ + id: "rule-456", + systemType: SystemType.TO_REPLY, + } as any); - await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); - - expect(saveLearnedPatterns).not.toHaveBeenCalled(); - }); - - it("should handle multiple label removals in a single event (only Cold Email is supported)", async () => { const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ - "label-3", + "label-4", ]); await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); - expect(saveLearnedPatterns).not.toHaveBeenCalled(); + expect(saveLearnedPattern).not.toHaveBeenCalled(); }); it("should skip processing when only system labels are removed", async () => { @@ -183,7 +144,7 @@ describe("process-label-removed-event", () => { // Should not try to fetch the message when only system labels removed expect(mockProvider.getMessage).not.toHaveBeenCalled(); - expect(prisma.coldEmail.upsert).not.toHaveBeenCalled(); + expect(saveLearnedPattern).not.toHaveBeenCalled(); }); it("should skip processing when DRAFT label is removed (prevents 404 errors)", async () => { @@ -194,9 +155,8 @@ describe("process-label-removed-event", () => { await handleLabelRemovedEvent(historyItem, defaultOptions, logger); - // Should not try to fetch the message (which would fail with 404) expect(mockProvider.getMessage).not.toHaveBeenCalled(); - expect(prisma.coldEmail.upsert).not.toHaveBeenCalled(); + expect(saveLearnedPattern).not.toHaveBeenCalled(); }); it("should skip processing when messageId is missing", async () => { @@ -207,7 +167,7 @@ describe("process-label-removed-event", () => { await handleLabelRemovedEvent(historyItem, defaultOptions, logger); - expect(prisma.coldEmail.upsert).not.toHaveBeenCalled(); + expect(saveLearnedPattern).not.toHaveBeenCalled(); }); it("should skip processing when threadId is missing", async () => { @@ -218,7 +178,46 @@ describe("process-label-removed-event", () => { await handleLabelRemovedEvent(historyItem, defaultOptions, logger); - expect(prisma.coldEmail.upsert).not.toHaveBeenCalled(); + expect(saveLearnedPattern).not.toHaveBeenCalled(); + }); + + it("should handle multiple label removals in a single event", async () => { + vi.mocked(prisma.rule.findFirst) + .mockResolvedValueOnce({ + id: "rule-1", + systemType: SystemType.COLD_EMAIL, + } as any) + .mockResolvedValueOnce({ + id: "rule-2", + systemType: SystemType.NEWSLETTER, + } as any); + + const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ + "label-1", + "label-2", + ]); + + await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); + + expect(saveLearnedPattern).toHaveBeenCalledTimes(2); + expect(saveLearnedPattern).toHaveBeenCalledWith( + expect.objectContaining({ ruleId: "rule-1" }), + ); + expect(saveLearnedPattern).toHaveBeenCalledWith( + expect.objectContaining({ ruleId: "rule-2" }), + ); + }); + + it("should skip learning when no rule is found for the removed label", async () => { + vi.mocked(prisma.rule.findFirst).mockResolvedValue(null); + + const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [ + "unknown-label", + ]); + + await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger); + + expect(saveLearnedPattern).not.toHaveBeenCalled(); }); }); }); diff --git a/apps/web/app/api/google/webhook/process-label-removed-event.ts b/apps/web/app/api/google/webhook/process-label-removed-event.ts index 3e2bbb5b5e..2ae4c9611a 100644 --- a/apps/web/app/api/google/webhook/process-label-removed-event.ts +++ b/apps/web/app/api/google/webhook/process-label-removed-event.ts @@ -1,12 +1,13 @@ import type { gmail_v1 } from "@googleapis/gmail"; -import prisma from "@/utils/prisma"; -import { ColdEmailStatus, SystemType } from "@/generated/prisma/enums"; +import { GroupItemSource, ActionType } from "@/generated/prisma/enums"; +import { saveLearnedPattern } from "@/utils/rule/learned-patterns"; import { extractEmailAddress } from "@/utils/email"; import type { EmailAccountWithAI } from "@/utils/llms/types"; import type { EmailProvider } from "@/utils/email/types"; import { GmailLabel } from "@/utils/gmail/label"; -import { getRuleLabel } from "@/utils/rule/consts"; +import { shouldLearnFromLabelRemoval } from "@/utils/rule/consts"; import type { Logger } from "@/utils/logger"; +import prisma from "@/utils/prisma"; import { isGmailRateLimitExceededError, isGmailQuotaExceededError, @@ -115,22 +116,10 @@ export async function handleLabelRemovedEvent( return; } - const labels = await provider.getLabels(); - for (const labelId of removedLabelIds) { - const label = labels?.find((l) => l.id === labelId); - const labelName = label?.name; - - if (!labelName) { - logger.info("Skipping label removal - missing label name", { - labelId, - }); - continue; - } - try { await learnFromRemovedLabel({ - labelName, + labelId, sender, messageId, threadId, @@ -140,7 +129,7 @@ export async function handleLabelRemovedEvent( } catch (error) { logger.error("Error learning from label removal", { error, - labelName, + labelId, removedLabelIds, }); } @@ -148,21 +137,21 @@ export async function handleLabelRemovedEvent( } async function learnFromRemovedLabel({ - labelName, + labelId, sender, messageId, threadId, emailAccountId, logger, }: { - labelName: string; + labelId: string; sender: string | null; messageId: string; threadId: string; emailAccountId: string; logger: Logger; }) { - logger = logger.with({ labelName, sender }); + logger = logger.with({ labelId, sender }); // Can't learn patterns without knowing who to exclude if (!sender) { @@ -170,28 +159,41 @@ async function learnFromRemovedLabel({ return; } - if (labelName === getRuleLabel(SystemType.COLD_EMAIL)) { - logger.info("Processing Cold Email label removal"); - - await prisma.coldEmail.upsert({ - where: { - emailAccountId_fromEmail: { - emailAccountId, - fromEmail: sender, + // Find rule with matching label action + const rule = await prisma.rule.findFirst({ + where: { + emailAccountId, + systemType: { not: null }, + actions: { + some: { + labelId: labelId, + type: ActionType.LABEL, }, }, - update: { - status: ColdEmailStatus.USER_REJECTED_COLD, - }, - create: { - status: ColdEmailStatus.USER_REJECTED_COLD, - fromEmail: sender, - emailAccountId, - messageId, - threadId, - }, - }); + }, + select: { id: true, systemType: true }, + }); + if (!rule?.systemType || !shouldLearnFromLabelRemoval(rule.systemType)) { + logger.info("Label removal does not match a learnable system rule", { + systemType: rule?.systemType, + }); return; } + + logger.info("Processing label removal for learning", { + systemType: rule.systemType, + }); + + await saveLearnedPattern({ + emailAccountId, + from: sender, + ruleId: rule.id, + exclude: true, + logger, + messageId, + threadId, + reason: "Label removed", + source: GroupItemSource.LABEL_REMOVED, + }); } diff --git a/apps/web/app/api/resend/summary/route.ts b/apps/web/app/api/resend/summary/route.ts index 8679286cd9..7fe1918494 100644 --- a/apps/web/app/api/resend/summary/route.ts +++ b/apps/web/app/api/resend/summary/route.ts @@ -7,7 +7,7 @@ import { env } from "@/env"; import { hasCronSecret } from "@/utils/cron"; import { captureException } from "@/utils/error"; import prisma from "@/utils/prisma"; -import { ThreadTrackerType } from "@/generated/prisma/enums"; +import { SystemType, ThreadTrackerType } from "@/generated/prisma/enums"; import type { Logger } from "@/utils/logger"; import { getMessagesBatch } from "@/utils/gmail/message"; import { decodeSnippet } from "@/utils/gmail/decode"; @@ -110,7 +110,6 @@ async function sendEmail({ where: { id: emailAccountId }, select: { email: true, - coldEmails: { where: { createdAt: { gt: cutOffDate } } }, account: { select: { access_token: true, @@ -124,84 +123,78 @@ async function sendEmail({ return { success: false }; } - if (emailAccount) { - logger.info("Email account found"); - } else { - logger.error("Email account not found or cutoff date is in the future", { - cutOffDate, - }); - return { success: true }; - } - - // Get counts and recent threads for each type - const [ - counts, - needsReply, - awaitingReply, - // needsAction - ] = await Promise.all([ - // total count - // NOTE: should really be distinct by threadId. this will cause a mismatch in some cases - prisma.threadTracker.groupBy({ - by: ["type"], - where: { - emailAccountId, - resolved: false, - }, - _count: true, - }), - // needs reply - prisma.threadTracker.findMany({ - where: { + const coldEmailRule = await prisma.rule.findUnique({ + where: { + emailAccountId_systemType: { emailAccountId, - type: ThreadTrackerType.NEEDS_REPLY, - resolved: false, + systemType: SystemType.COLD_EMAIL, }, - orderBy: { sentAt: "desc" }, - take: 20, - distinct: ["threadId"], - }), - // awaiting reply - prisma.threadTracker.findMany({ - where: { - emailAccountId, - type: ThreadTrackerType.AWAITING, - resolved: false, - // only show emails that are more than 3 days overdue - sentAt: { lt: subHours(new Date(), 24 * 3) }, - }, - orderBy: { sentAt: "desc" }, - take: 20, - distinct: ["threadId"], - }), - // needs action - currently not used - // prisma.threadTracker.findMany({ - // where: { - // userId: user.id, - // type: ThreadTrackerType.NEEDS_ACTION, - // resolved: false, - // }, - // orderBy: { sentAt: "desc" }, - // take: 20, - // distinct: ["threadId"], - // }), - ]); + }, + select: { id: true }, + }); + + // Get counts and recent threads for each type + const [counts, needsReply, awaitingReply, coldExecutedRules] = + await Promise.all([ + // total count + // NOTE: should really be distinct by threadId. this will cause a mismatch in some cases + prisma.threadTracker.groupBy({ + by: ["type"], + where: { + emailAccountId, + resolved: false, + }, + _count: true, + }), + // needs reply + prisma.threadTracker.findMany({ + where: { + emailAccountId, + type: ThreadTrackerType.NEEDS_REPLY, + resolved: false, + }, + orderBy: { sentAt: "desc" }, + take: 20, + distinct: ["threadId"], + }), + // awaiting reply + prisma.threadTracker.findMany({ + where: { + emailAccountId, + type: ThreadTrackerType.AWAITING, + resolved: false, + // only show emails that are more than 3 days overdue + sentAt: { lt: subHours(new Date(), 24 * 3) }, + }, + orderBy: { sentAt: "desc" }, + take: 20, + distinct: ["threadId"], + }), + // cold emails + coldEmailRule + ? prisma.executedRule.findMany({ + where: { + ruleId: coldEmailRule.id, + automated: true, + createdAt: { gt: cutOffDate }, + }, + select: { + messageId: true, + createdAt: true, + }, + }) + : Promise.resolve([]), + ]); const typeCounts = Object.fromEntries( counts.map((count) => [count.type, count._count]), ); - const coldEmailers = emailAccount.coldEmails.map((e) => ({ - from: e.fromEmail, - subject: "", - sentAt: e.createdAt, - })); - // get messages const messageIds = [ ...needsReply.map((m) => m.messageId), ...awaitingReply.map((m) => m.messageId), - // ...needsAction.map((m) => m.messageId), + ...coldExecutedRules.map((r) => r.messageId), ]; logger.info("Getting messages", { @@ -237,14 +230,14 @@ async function sendEmail({ }; }); - // const recentNeedsAction = needsAction.map((t) => { - // const message = messageMap[t.messageId]; - // return { - // from: message?.headers.from || "Unknown", - // subject: decodeSnippet(message?.snippet) || "", - // sentAt: t.sentAt, - // }; - // }); + const coldEmailers = coldExecutedRules.map((r) => { + const message = messageMap[r.messageId]; + return { + from: message?.headers.from || "Unknown", + subject: decodeSnippet(message?.snippet) || "", + sentAt: r.createdAt, + }; + }); const shouldSendEmail = !!( coldEmailers.length || @@ -281,7 +274,6 @@ async function sendEmail({ needsActionCount: typeCounts[ThreadTrackerType.NEEDS_ACTION], needsReply: recentNeedsReply, awaitingReply: recentAwaitingReply, - // needsAction: recentNeedsAction, unsubscribeToken: token, }, }); diff --git a/apps/web/app/api/user/cold-email/route.ts b/apps/web/app/api/user/cold-email/route.ts index d7047e9711..3f517fd500 100644 --- a/apps/web/app/api/user/cold-email/route.ts +++ b/apps/web/app/api/user/cold-email/route.ts @@ -1,7 +1,11 @@ import { NextResponse } from "next/server"; import prisma from "@/utils/prisma"; import { withEmailAccount } from "@/utils/middleware"; -import { ColdEmailStatus } from "@/generated/prisma/enums"; +import { + ColdEmailStatus, + GroupItemType, + SystemType, +} from "@/generated/prisma/enums"; const LIMIT = 50; @@ -14,30 +18,54 @@ async function getColdEmails( }: { emailAccountId: string; status: ColdEmailStatus }, page: number, ) { + const coldEmailRule = await prisma.rule.findUnique({ + where: { + emailAccountId_systemType: { + emailAccountId, + systemType: SystemType.COLD_EMAIL, + }, + }, + select: { id: true, groupId: true }, + }); + + if (!coldEmailRule?.groupId) { + return { coldEmails: [], totalPages: 0 }; + } + const where = { - emailAccountId, - status, + groupId: coldEmailRule.groupId, + type: GroupItemType.FROM, + exclude: status === ColdEmailStatus.USER_REJECTED_COLD, }; - const [coldEmails, count] = await Promise.all([ - prisma.coldEmail.findMany({ + const [groupItems, count] = await Promise.all([ + prisma.groupItem.findMany({ where, take: LIMIT, skip: (page - 1) * LIMIT, orderBy: { createdAt: "desc" }, select: { id: true, - fromEmail: true, - status: true, + value: true, createdAt: true, reason: true, threadId: true, messageId: true, }, }), - prisma.coldEmail.count({ where }), + prisma.groupItem.count({ where }), ]); + const coldEmails = groupItems.map((item) => ({ + id: item.id, + fromEmail: item.value, + status: status, + createdAt: item.createdAt, + reason: item.reason, + threadId: item.threadId, + messageId: item.messageId, + })); + return { coldEmails, totalPages: Math.ceil(count / LIMIT) }; } diff --git a/apps/web/components/EmailMessageCell.tsx b/apps/web/components/EmailMessageCell.tsx index 84aebc7e66..f3b7ac29b9 100644 --- a/apps/web/components/EmailMessageCell.tsx +++ b/apps/web/components/EmailMessageCell.tsx @@ -144,7 +144,8 @@ export function EmailMessageCellWithData({ }) { const { data, isLoading, error } = useThread({ id: threadId }); - const firstMessage = data?.thread.messages?.[0]; + const firstMessage = data?.thread?.messages?.[0]; + const emailNotFound = !isLoading && !error && !firstMessage; return ( ); } diff --git a/apps/web/prisma/migrations/20260103000000_migrate_cold_emails_to_group_items/migration.sql b/apps/web/prisma/migrations/20260103000000_migrate_cold_emails_to_group_items/migration.sql new file mode 100644 index 0000000000..e977ea830d --- /dev/null +++ b/apps/web/prisma/migrations/20260103000000_migrate_cold_emails_to_group_items/migration.sql @@ -0,0 +1,97 @@ +-- CreateEnum +CREATE TYPE "GroupItemSource" AS ENUM ('AI', 'USER'); + +-- AlterTable +ALTER TABLE "GroupItem" ADD COLUMN "messageId" TEXT, +ADD COLUMN "reason" TEXT, +ADD COLUMN "source" "GroupItemSource", +ADD COLUMN "threadId" TEXT; + +-- Migrate ColdEmail data to GroupItem +-- This migration moves historical cold email data from the deprecated ColdEmail table +-- to the unified GroupItem table (learned patterns system) + +-- Step 1: Create Groups for Cold Email rules that don't have one yet +-- If a group with the same name and emailAccountId already exists, reuse it; otherwise create a new group +DO $$ +DECLARE + rule_record RECORD; + new_group_id TEXT; + group_name TEXT; +BEGIN + FOR rule_record IN + SELECT r.id, r.name, r."emailAccountId" + FROM "Rule" r + WHERE r."systemType" = 'COLD_EMAIL' + AND r."groupId" IS NULL + LOOP + new_group_id := gen_random_uuid()::TEXT; + group_name := rule_record.name; + + -- Check if a group with this name already exists + IF EXISTS ( + SELECT 1 FROM "Group" g + WHERE g.name = group_name AND g."emailAccountId" = rule_record."emailAccountId" + ) THEN + -- Use existing group if it exists + UPDATE "Rule" + SET "groupId" = ( + SELECT id FROM "Group" g + WHERE g.name = group_name AND g."emailAccountId" = rule_record."emailAccountId" + LIMIT 1 + ) + WHERE id = rule_record.id; + ELSE + -- Create new group + INSERT INTO "Group" (id, "createdAt", "updatedAt", name, "emailAccountId") + VALUES (new_group_id, NOW(), NOW(), group_name, rule_record."emailAccountId"); + + UPDATE "Rule" SET "groupId" = new_group_id WHERE id = rule_record.id; + END IF; + END LOOP; +END $$; + +-- Step 2: Migrate ColdEmail records to GroupItem +INSERT INTO "GroupItem" ( + id, + "createdAt", + "updatedAt", + "groupId", + type, + value, + exclude, + reason, + "threadId", + "messageId", + source +) +SELECT + gen_random_uuid() as id, + ce."createdAt", + ce."updatedAt", + r."groupId", + 'FROM'::"GroupItemType" as type, + ce."fromEmail" as value, + CASE + WHEN ce.status = 'USER_REJECTED_COLD' THEN true + ELSE false + END as exclude, + ce.reason, + ce."threadId", + ce."messageId", + CASE + WHEN ce.status = 'USER_REJECTED_COLD' THEN 'USER'::"GroupItemSource" + ELSE 'AI'::"GroupItemSource" + END as source +FROM "ColdEmail" ce +JOIN "Rule" r ON r."emailAccountId" = ce."emailAccountId" AND r."systemType" = 'COLD_EMAIL' +WHERE r."groupId" IS NOT NULL + AND ce."fromEmail" IS NOT NULL + -- Avoid duplicates: only insert if this pattern doesn't already exist + AND NOT EXISTS ( + SELECT 1 FROM "GroupItem" gi + WHERE gi."groupId" = r."groupId" + AND gi.type = 'FROM' + AND gi.value = ce."fromEmail" + ); + diff --git a/apps/web/prisma/migrations/20260104000000_add_label_removed_to_group_item_source/migration.sql b/apps/web/prisma/migrations/20260104000000_add_label_removed_to_group_item_source/migration.sql new file mode 100644 index 0000000000..e9a11409f5 --- /dev/null +++ b/apps/web/prisma/migrations/20260104000000_add_label_removed_to_group_item_source/migration.sql @@ -0,0 +1,3 @@ +-- AlterEnum +ALTER TYPE "GroupItemSource" ADD VALUE 'LABEL_REMOVED'; + diff --git a/apps/web/prisma/schema.prisma b/apps/web/prisma/schema.prisma index 3cd1d88e14..846623fb59 100644 --- a/apps/web/prisma/schema.prisma +++ b/apps/web/prisma/schema.prisma @@ -156,7 +156,7 @@ model EmailAccount { rules Rule[] executedRules ExecutedRule[] newsletters Newsletter[] - coldEmails ColdEmail[] + coldEmails ColdEmail[] // @deprecated - kept for backward compatibility during migration groups Group[] categories Category[] threadTrackers ThreadTracker[] @@ -277,12 +277,12 @@ model DigestItem { digest Digest @relation(fields: [digestId], references: [id], onDelete: Cascade) actionId String? action ExecutedAction? @relation(fields: [actionId], references: [id], onDelete: Cascade) - coldEmailId String? - coldEmail ColdEmail? @relation(fields: [coldEmailId], references: [id]) + coldEmailId String? // @deprecated + coldEmail ColdEmail? @relation(fields: [coldEmailId], references: [id]) // @deprecated @@unique([digestId, threadId, messageId]) @@index([actionId]) - @@index([coldEmailId]) + @@index([coldEmailId]) // @deprecated } model Schedule { @@ -620,6 +620,12 @@ model GroupItem { value String // eg "@gmail.com", "matt@gmail.com", "Receipt from" exclude Boolean @default(false) // Whether this pattern should be excluded rather than included + // Optional context for why/how this pattern was learned. + reason String? + threadId String? + messageId String? + source GroupItemSource? // provides value for UI/audit. + @@unique([groupId, type, value]) } @@ -664,6 +670,9 @@ model Newsletter { @@index([categoryId]) } +// @deprecated - ColdEmail data is being migrated to GroupItem (learned patterns). +// This model is kept for backward compatibility during the migration period. +// Once all users have run the migration, this model can be deleted. model ColdEmail { id String @id @default(cuid()) createdAt DateTime @default(now()) @@ -1189,3 +1198,9 @@ enum MeetingBriefingStatus { FAILED SKIPPED } + +enum GroupItemSource { + AI + USER + LABEL_REMOVED +} diff --git a/apps/web/utils/actions/cold-email.ts b/apps/web/utils/actions/cold-email.ts index ef633f29dc..6f877dadde 100644 --- a/apps/web/utils/actions/cold-email.ts +++ b/apps/web/utils/actions/cold-email.ts @@ -1,7 +1,7 @@ "use server"; import prisma from "@/utils/prisma"; -import { ColdEmailStatus, SystemType } from "@/generated/prisma/enums"; +import { GroupItemSource } from "@/generated/prisma/enums"; import { emailToContent } from "@/utils/mail"; import { isColdEmail } from "@/utils/cold-email/is-cold-email"; import { @@ -13,8 +13,8 @@ import { SafeError } from "@/utils/error"; import { createEmailProvider } from "@/utils/email/provider"; import type { EmailProvider } from "@/utils/email/types"; import { getColdEmailRule } from "@/utils/cold-email/cold-email-rule"; -import { getRuleLabel } from "@/utils/rule/consts"; import { internalDateToDate } from "@/utils/date"; +import { saveLearnedPattern } from "@/utils/rule/learned-patterns"; export const markNotColdEmailAction = actionClient .metadata({ name: "markNotColdEmail" }) @@ -24,74 +24,52 @@ export const markNotColdEmailAction = actionClient ctx: { emailAccountId, provider, logger }, parsedInput: { sender }, }) => { - const emailProvider = await createEmailProvider({ - emailAccountId, - provider, - logger, - }); + const [emailProvider, coldEmailRule] = await Promise.all([ + createEmailProvider({ + emailAccountId, + provider, + logger, + }), + getColdEmailRule(emailAccountId), + ]); + + if (!coldEmailRule) { + throw new SafeError("Cold email rule not found"); + } await Promise.all([ - prisma.coldEmail.update({ - where: { - emailAccountId_fromEmail: { - emailAccountId, - fromEmail: sender, - }, - }, - data: { - status: ColdEmailStatus.USER_REJECTED_COLD, - }, + // Mark as excluded so AI doesn't match it again + saveLearnedPattern({ + emailAccountId, + from: sender, + ruleId: coldEmailRule.id, + exclude: true, + logger, + source: GroupItemSource.USER, }), - removeColdEmailLabelFromSender(emailAccountId, emailProvider, sender), + removeColdEmailLabelFromSender(emailProvider, sender, coldEmailRule), ]); }, ); -/** - * Helper function to get threads from a specific sender using the email provider - */ -async function getThreadsFromSender( - emailProvider: EmailProvider, - sender: string, - labelId?: string, -): Promise<{ id: string }[]> { - const { threads } = await emailProvider.getThreadsWithQuery({ - query: { - fromEmail: sender, - labelId, - }, - maxResults: 100, - }); - - return threads.map((thread) => ({ id: thread.id })); -} - async function removeColdEmailLabelFromSender( - emailAccountId: string, emailProvider: EmailProvider, sender: string, + coldEmailRule: { actions: { labelId: string | null }[] }, ) { - // 1. find cold email label - // 2. find emails from sender - // 3. remove cold email label from emails - - const coldEmailRule = await getColdEmailRule(emailAccountId); - if (!coldEmailRule) return; + const labelIds = coldEmailRule.actions + .map((action) => action.labelId) + .filter((id): id is string => Boolean(id)); - const labels = await emailProvider.getLabels(); + if (labelIds.length === 0) return; - // NOTE: this doesn't work completely if the user set 2 labels: - const label = - labels.find((label) => label.id === coldEmailRule.actions?.[0]?.labelId) || - labels.find((label) => label.name === getRuleLabel(SystemType.COLD_EMAIL)); - - if (!label?.id) return; - - const threads = await getThreadsFromSender(emailProvider, sender, label.id); + const { threads } = await emailProvider.getThreadsWithQuery({ + query: { fromEmail: sender }, + maxResults: 100, + }); for (const thread of threads) { - if (!thread.id) continue; - await emailProvider.removeThreadLabel(thread.id, label.id); + await emailProvider.removeThreadLabels(thread.id, labelIds); } } diff --git a/apps/web/utils/ai/choose-rule/match-rules.test.ts b/apps/web/utils/ai/choose-rule/match-rules.test.ts index bd7076ce7e..e8c6bdb2f2 100644 --- a/apps/web/utils/ai/choose-rule/match-rules.test.ts +++ b/apps/web/utils/ai/choose-rule/match-rules.test.ts @@ -1653,34 +1653,121 @@ describe("filterToReplyPreset", () => { }); function getRule(overrides: Partial = {}): RuleWithActions { + const { + id = "r123", + createdAt = new Date(), + updatedAt = new Date(), + name = "Rule Name", + enabled = true, + automate = true, + runOnThreads = true, + emailAccountId = "emailAccountId", + conditionalOperator = LogicalOperator.AND, + instructions = null, + groupId = null, + from = null, + to = null, + subject = null, + body = null, + categoryFilterType = null, + systemType = null, + promptText = null, + actions = [], + } = overrides; + return { - id: "r123", - userId: "userId", - runOnThreads: true, - conditionalOperator: LogicalOperator.AND, - type: null, - systemType: null, - ...overrides, - } as RuleWithActions; + id, + createdAt, + updatedAt, + name, + enabled, + automate, + runOnThreads, + emailAccountId, + conditionalOperator, + instructions, + groupId, + from, + to, + subject, + body, + categoryFilterType, + systemType, + promptText, + actions, + }; } function getHeaders( overrides: Partial = {}, ): ParsedMessageHeaders { + const { + subject = "Subject", + from = "from@example.com", + to = "to@example.com", + cc, + bcc, + date = new Date().toISOString(), + "message-id": messageId, + "reply-to": replyTo, + "in-reply-to": inReplyTo, + references, + "list-unsubscribe": listUnsubscribe, + } = overrides; + return { - ...overrides, - } as ParsedMessageHeaders; + subject, + from, + to, + cc, + bcc, + date, + "message-id": messageId, + "reply-to": replyTo, + "in-reply-to": inReplyTo, + references, + "list-unsubscribe": listUnsubscribe, + }; } function getMessage(overrides: Partial = {}): ParsedMessage { - const message = { - id: "m1", - threadId: "m1", - headers: getHeaders(), - ...overrides, - }; + const { + id = "m1", + threadId = "m1", + labelIds = [], + snippet = "snippet", + historyId = "h1", + attachments = [], + inline = [], + headers = getHeaders(), + textPlain = "textPlain", + textHtml = "textHtml", + subject = "subject", + date = new Date().toISOString(), + conversationIndex = null, + internalDate = null, + bodyContentType, + rawRecipients, + } = overrides; - return message as ParsedMessage; + return { + id, + threadId, + labelIds, + snippet, + historyId, + attachments, + inline, + headers, + textPlain, + textHtml, + subject, + date, + conversationIndex, + internalDate, + bodyContentType, + rawRecipients, + }; } function getGroup( @@ -1688,29 +1775,56 @@ function getGroup( Prisma.GroupGetPayload<{ include: { items: true; rule: true } }> > = {}, ): Prisma.GroupGetPayload<{ include: { items: true; rule: true } }> { + const { + id = "group1", + name = "group", + createdAt = new Date(), + updatedAt = new Date(), + emailAccountId = "emailAccountId", + prompt = null, + items = [], + rule = null, + } = overrides; + return { - id: "group1", - name: "group", - createdAt: new Date(), - updatedAt: new Date(), - emailAccountId: "emailAccountId", - prompt: null, - items: [], - rule: null, - ...overrides, + id, + name, + createdAt, + updatedAt, + emailAccountId, + prompt, + items, + rule, }; } function getGroupItem(overrides: Partial = {}): GroupItem { + const { + id = "groupItem1", + createdAt = new Date(), + updatedAt = new Date(), + groupId = "groupId", + type = GroupItemType.FROM, + value = "test@example.com", + exclude = false, + reason = null, + threadId = null, + messageId = null, + source = null, + } = overrides; + return { - id: "groupItem1", - createdAt: new Date(), - updatedAt: new Date(), - groupId: "groupId", - type: GroupItemType.FROM, - value: "test@example.com", - exclude: false, - ...overrides, + id, + createdAt, + updatedAt, + groupId, + type, + value, + exclude, + reason, + threadId, + messageId, + source, }; } diff --git a/apps/web/utils/ai/choose-rule/match-rules.ts b/apps/web/utils/ai/choose-rule/match-rules.ts index 52dcf1f493..36c902478a 100644 --- a/apps/web/utils/ai/choose-rule/match-rules.ts +++ b/apps/web/utils/ai/choose-rule/match-rules.ts @@ -84,7 +84,7 @@ export async function findMatchingRules({ matchReasons: [{ type: ConditionType.AI }], }, ], - reasoning: coldEmailResult.reason, + reasoning: coldEmailResult.aiReason || coldEmailResult.reason, }; } } diff --git a/apps/web/utils/ai/choose-rule/run-rules.test.ts b/apps/web/utils/ai/choose-rule/run-rules.test.ts index 67e30b5ec3..0cbc73573e 100644 --- a/apps/web/utils/ai/choose-rule/run-rules.test.ts +++ b/apps/web/utils/ai/choose-rule/run-rules.test.ts @@ -38,8 +38,9 @@ vi.mock("@/utils/ai/choose-rule/execute", () => ({ vi.mock("@/utils/reply-tracker/label-helpers", () => ({ removeConflictingThreadStatusLabels: vi.fn(), })); -vi.mock("@/utils/cold-email/is-cold-email", () => ({ - saveColdEmail: vi.fn(), +vi.mock("@/utils/rule/learned-patterns", () => ({ + saveLearnedPattern: vi.fn(), + saveLearnedPatterns: vi.fn(), })); vi.mock("@/utils/scheduled-actions/scheduler", () => ({ scheduleDelayedActions: vi.fn(), diff --git a/apps/web/utils/ai/choose-rule/run-rules.ts b/apps/web/utils/ai/choose-rule/run-rules.ts index e3dba35139..48aed7fd79 100644 --- a/apps/web/utils/ai/choose-rule/run-rules.ts +++ b/apps/web/utils/ai/choose-rule/run-rules.ts @@ -4,6 +4,7 @@ import type { EmailAccountWithAI } from "@/utils/llms/types"; import { ActionType, ExecutedRuleStatus, + GroupItemSource, SystemType, } from "@/generated/prisma/enums"; import type { Rule } from "@/generated/prisma/client"; @@ -34,7 +35,7 @@ import { updateThreadTrackers, } from "@/utils/reply-tracker/handle-conversation-status"; import { removeConflictingThreadStatusLabels } from "@/utils/reply-tracker/label-helpers"; -import { saveColdEmail } from "@/utils/cold-email/is-cold-email"; +import { saveLearnedPattern } from "@/utils/rule/learned-patterns"; import { internalDateToDate } from "@/utils/date"; import { ConditionType } from "@/utils/config"; import type { Logger } from "@/utils/logger"; @@ -340,14 +341,17 @@ async function executeMatchedRule( }); if (rule.systemType === SystemType.COLD_EMAIL) { - await saveColdEmail({ - email: { - id: message.id, - threadId: message.threadId, - from: message.headers.from, - }, - emailAccount, - aiReason: reason ?? null, + const from = + extractEmailAddress(message.headers.from) || message.headers.from; + await saveLearnedPattern({ + emailAccountId: emailAccount.id, + from, + ruleId: rule.id, + logger, + reason, + messageId: message.id, + threadId: message.threadId, + source: GroupItemSource.AI, }); } diff --git a/apps/web/utils/ai/choose-rule/types.ts b/apps/web/utils/ai/choose-rule/types.ts index 9be6ad3e11..da3aad477e 100644 --- a/apps/web/utils/ai/choose-rule/types.ts +++ b/apps/web/utils/ai/choose-rule/types.ts @@ -41,7 +41,7 @@ export type MatchingRuleResult = { /** * Serializable version of MatchReason for database storage */ -export type SerializedMatchReason = +type SerializedMatchReason = | { type: "STATIC" } | { type: "LEARNED_PATTERN"; diff --git a/apps/web/utils/cold-email/cold-email-rule.ts b/apps/web/utils/cold-email/cold-email-rule.ts index c228396c8c..501168017b 100644 --- a/apps/web/utils/cold-email/cold-email-rule.ts +++ b/apps/web/utils/cold-email/cold-email-rule.ts @@ -17,6 +17,7 @@ export async function getColdEmailRule(emailAccountId: string) { id: true, enabled: true, instructions: true, + groupId: true, actions: { select: { type: true, diff --git a/apps/web/utils/cold-email/is-cold-email.test.ts b/apps/web/utils/cold-email/is-cold-email.test.ts index a7d3931304..95dc0a67f3 100644 --- a/apps/web/utils/cold-email/is-cold-email.test.ts +++ b/apps/web/utils/cold-email/is-cold-email.test.ts @@ -1,20 +1,16 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; -import { isColdEmail, saveColdEmail } from "./is-cold-email"; +import { isColdEmail } from "./is-cold-email"; import { getEmailAccount } from "@/__tests__/helpers"; import type { EmailForLLM } from "@/utils/types"; -import { ColdEmailStatus } from "@/generated/prisma/enums"; -import prisma from "@/utils/prisma"; +import { GroupItemType } from "@/generated/prisma/enums"; +import prisma from "@/utils/__mocks__/prisma"; import { extractEmailAddress } from "@/utils/email"; vi.mock("server-only", () => ({})); +vi.mock("@/utils/prisma"); -vi.mock("@/utils/prisma", () => ({ - default: { - coldEmail: { - findUnique: vi.fn(), - upsert: vi.fn(), - }, - }, +vi.mock("./cold-email-rule", () => ({ + getColdEmailRule: vi.fn(), })); vi.mock("@/utils/email", async () => { @@ -41,34 +37,15 @@ describe("isColdEmail", () => { it("should recognize a known cold email sender even when from field format differs", async () => { const emailAccount = getEmailAccount({ id: "test-account-id" }); const normalizedEmail = "cold.sender@example.com"; + const groupId = "test-group-id"; - // First, simulate saving a cold email with normalized email address - // This is what saveColdEmail does - it extracts just the email address - vi.mocked(prisma.coldEmail.upsert).mockResolvedValue({ - id: "cold-email-id", - emailAccountId: emailAccount.id, - fromEmail: normalizedEmail, - status: ColdEmailStatus.AI_LABELED_COLD, - reason: "Test reason", - messageId: "msg1", - threadId: "thread1", - createdAt: new Date(), - updatedAt: new Date(), - }); + // Mock groupItem lookup + vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({ + id: "group-item-id", + exclude: false, + } as any); - await saveColdEmail({ - email: { - from: normalizedEmail, - id: "msg1", - threadId: "thread1", - }, - emailAccount, - aiReason: "Test reason", - }); - - // Now simulate a second email from the same sender but with a different format - // This is the bug scenario: the from field has a display name - const secondEmail: EmailForLLM = { + const email: EmailForLLM = { id: "msg2", from: `"Cold Sender" <${normalizedEmail}>`, to: emailAccount.email, @@ -77,93 +54,91 @@ describe("isColdEmail", () => { date: new Date(), }; - // Mock Prisma to return the cold email record when queried with normalized email - vi.mocked(prisma.coldEmail.findUnique).mockResolvedValue({ - id: "cold-email-id", - emailAccountId: emailAccount.id, - fromEmail: normalizedEmail, - status: ColdEmailStatus.AI_LABELED_COLD, - reason: "Test reason", - messageId: "msg1", - threadId: "thread1", - createdAt: new Date(), - updatedAt: new Date(), - }); - const result = await isColdEmail({ - email: secondEmail, + email, emailAccount, provider: mockProvider as never, - coldEmailRule: null, + coldEmailRule: { instructions: "test instructions", groupId }, }); - // This test should pass after the fix - the sender should be recognized as cold expect(result.isColdEmail).toBe(true); expect(result.reason).toBe("ai-already-labeled"); - // Verify that findUnique was called with the normalized email address - expect(prisma.coldEmail.findUnique).toHaveBeenCalledWith({ + // Verify that findFirst was called with the normalized email address + expect(prisma.groupItem.findFirst).toHaveBeenCalledWith({ where: { - emailAccountId_fromEmail: { - emailAccountId: emailAccount.id, - fromEmail: normalizedEmail, - }, - status: ColdEmailStatus.AI_LABELED_COLD, + groupId, + type: GroupItemType.FROM, + value: normalizedEmail, + }, + select: { exclude: true }, + }); + }); + + it("should return excluded when sender is explicitly excluded from cold email blocker", async () => { + const emailAccount = getEmailAccount({ id: "test-account-id" }); + const normalizedEmail = "excluded.sender@example.com"; + const groupId = "test-group-id"; + + // Mock groupItem lookup with exclude: true + vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({ + id: "group-item-id", + exclude: true, + } as any); + + const email: EmailForLLM = { + id: "msg-excluded", + from: `"Excluded Sender" <${normalizedEmail}>`, + to: emailAccount.email, + subject: "Not a cold email", + content: "This sender was explicitly excluded", + date: new Date(), + }; + + const result = await isColdEmail({ + email, + emailAccount, + provider: mockProvider as never, + coldEmailRule: { instructions: "test instructions", groupId }, + }); + + expect(result.isColdEmail).toBe(false); + expect(result.reason).toBe("excluded"); + + expect(prisma.groupItem.findFirst).toHaveBeenCalledWith({ + where: { + groupId, + type: GroupItemType.FROM, + value: normalizedEmail, }, - select: { id: true }, + select: { exclude: true }, }); }); it("should handle various email formats consistently", async () => { const emailAccount = getEmailAccount({ id: "test-account-id" }); const normalizedEmail = "sender@example.com"; + const groupId = "test-group-id"; + + vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({ + id: "group-item-id", + exclude: false, + } as any); - // Test different from field formats that should all resolve to the same normalized email const emailFormats = [ normalizedEmail, `<${normalizedEmail}>`, `"Display Name" <${normalizedEmail}>`, `Display Name <${normalizedEmail}>`, - ` ${normalizedEmail} `, // with spaces + ` ${normalizedEmail} `, ]; - // Mock Prisma to return cold email for normalized email - vi.mocked(prisma.coldEmail.findUnique).mockImplementation( - (args) => - new Promise((resolve) => { - const where = args?.where as - | { - emailAccountId_fromEmail: { - emailAccountId: string; - fromEmail: string; - }; - status: ColdEmailStatus; - } - | undefined; - - if ( - where?.emailAccountId_fromEmail.fromEmail === normalizedEmail && - where.status === ColdEmailStatus.AI_LABELED_COLD - ) { - resolve({ - id: "cold-email-id", - emailAccountId: emailAccount.id, - fromEmail: normalizedEmail, - status: ColdEmailStatus.AI_LABELED_COLD, - reason: "Test reason", - messageId: "msg1", - threadId: "thread1", - createdAt: new Date(), - updatedAt: new Date(), - } as never); - } else { - resolve(null as never); - } - }) as never, - ); - for (const fromFormat of emailFormats) { vi.clearAllMocks(); + vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({ + id: "group-item-id", + exclude: false, + } as any); const email: EmailForLLM = { id: "msg-test", @@ -178,23 +153,21 @@ describe("isColdEmail", () => { email, emailAccount, provider: mockProvider as never, - coldEmailRule: null, + coldEmailRule: { instructions: "test instructions", groupId }, }); expect(result.isColdEmail).toBe(true); expect(result.reason).toBe("ai-already-labeled"); - // Verify extractEmailAddress was used to normalize - const expectedNormalized = extractEmailAddress(fromFormat); - expect(prisma.coldEmail.findUnique).toHaveBeenCalledWith({ + const expectedNormalized = + extractEmailAddress(fromFormat) || fromFormat.trim(); + expect(prisma.groupItem.findFirst).toHaveBeenCalledWith({ where: { - emailAccountId_fromEmail: { - emailAccountId: emailAccount.id, - fromEmail: expectedNormalized, - }, - status: ColdEmailStatus.AI_LABELED_COLD, + groupId, + type: GroupItemType.FROM, + value: expectedNormalized, }, - select: { id: true }, + select: { exclude: true }, }); } }); diff --git a/apps/web/utils/cold-email/is-cold-email.ts b/apps/web/utils/cold-email/is-cold-email.ts index 5988da8edc..89e976ed9f 100644 --- a/apps/web/utils/cold-email/is-cold-email.ts +++ b/apps/web/utils/cold-email/is-cold-email.ts @@ -1,7 +1,7 @@ import { z } from "zod"; import type { EmailAccountWithAI } from "@/utils/llms/types"; -import type { ColdEmail, Rule } from "@/generated/prisma/client"; -import { ColdEmailStatus } from "@/generated/prisma/enums"; +import type { Rule } from "@/generated/prisma/client"; +import { GroupItemType } from "@/generated/prisma/enums"; import prisma from "@/utils/prisma"; import { DEFAULT_COLD_EMAIL_PROMPT } from "@/utils/cold-email/prompt"; import { stringifyEmail } from "@/utils/stringify-email"; @@ -14,7 +14,11 @@ import { extractEmailAddress } from "@/utils/email"; export const COLD_EMAIL_FOLDER_NAME = "Cold Emails"; -type ColdEmailBlockerReason = "hasPreviousEmail" | "ai" | "ai-already-labeled"; +type ColdEmailBlockerReason = + | "hasPreviousEmail" + | "ai" + | "ai-already-labeled" + | "excluded"; export async function isColdEmail({ email, @@ -27,7 +31,7 @@ export async function isColdEmail({ emailAccount: EmailAccountWithAI; provider: EmailProvider; modelType?: ModelType; - coldEmailRule: Pick | null; + coldEmailRule: Pick | null; }): Promise<{ isColdEmail: boolean; reason: ColdEmailBlockerReason; @@ -43,16 +47,31 @@ export async function isColdEmail({ logger.info("Checking is cold email"); // Check if we marked it as a cold email already - const isColdEmailer = await isKnownColdEmailSender({ - from: email.from, - emailAccountId: emailAccount.id, - }); + const groupId = coldEmailRule?.groupId; + let patternMatch: { exclude: boolean } | null = null; + + if (groupId) { + const normalizedFrom = extractEmailAddress(email.from) || email.from; + patternMatch = await prisma.groupItem.findFirst({ + where: { + groupId, + type: GroupItemType.FROM, + value: normalizedFrom, + }, + select: { exclude: true }, + }); + } - if (isColdEmailer) { - logger.info("Known cold email sender", { + if (patternMatch && !patternMatch.exclude) { + logger.info("Known cold email sender", { from: email.from }); + return { isColdEmail: true, reason: "ai-already-labeled" }; + } + + if (patternMatch?.exclude) { + logger.info("Sender explicitly excluded from cold email blocker", { from: email.from, }); - return { isColdEmail: true, reason: "ai-already-labeled" }; + return { isColdEmail: false, reason: "excluded" }; } const hasPreviousEmail = @@ -88,28 +107,6 @@ export async function isColdEmail({ }; } -async function isKnownColdEmailSender({ - from, - emailAccountId, -}: { - from: string; - emailAccountId: string; -}) { - const normalizedFrom = extractEmailAddress(from) || from; - - const coldEmail = await prisma.coldEmail.findUnique({ - where: { - emailAccountId_fromEmail: { - emailAccountId, - fromEmail: normalizedFrom, - }, - status: ColdEmailStatus.AI_LABELED_COLD, - }, - select: { id: true }, - }); - return !!coldEmail; -} - async function aiIsColdEmail( email: EmailForLLM, emailAccount: EmailAccountWithAI, @@ -161,33 +158,3 @@ ${stringifyEmail(email, 500)} return response.object; } - -export async function saveColdEmail({ - email, - emailAccount, - aiReason, -}: { - email: { from: string; id: string; threadId: string }; - emailAccount: EmailAccountWithAI; - aiReason: string | null; -}): Promise { - const from = extractEmailAddress(email.from) || email.from; - - return await prisma.coldEmail.upsert({ - where: { - emailAccountId_fromEmail: { - emailAccountId: emailAccount.id, - fromEmail: from, - }, - }, - update: { status: ColdEmailStatus.AI_LABELED_COLD }, - create: { - status: ColdEmailStatus.AI_LABELED_COLD, - fromEmail: from, - emailAccountId: emailAccount.id, - reason: aiReason, - messageId: email.id, - threadId: email.threadId, - }, - }); -} diff --git a/apps/web/utils/rule/consts.ts b/apps/web/utils/rule/consts.ts index 7167a696e7..50658248d6 100644 --- a/apps/web/utils/rule/consts.ts +++ b/apps/web/utils/rule/consts.ts @@ -14,6 +14,7 @@ const ruleConfig: Record< categoryAction: "label" | "label_archive" | "move_folder"; categoryActionMicrosoft?: "move_folder"; tooltipText: string; + shouldLearn: boolean; } > = { [SystemType.TO_REPLY]: { @@ -25,6 +26,7 @@ const ruleConfig: Record< categoryAction: "label", tooltipText: "Emails you need to reply to and those where you're awaiting a reply. The label will update automatically as the conversation progresses", + shouldLearn: false, }, [SystemType.FYI]: { name: "FYI", @@ -33,6 +35,7 @@ const ruleConfig: Record< runOnThreads: true, categoryAction: "label", tooltipText: "", + shouldLearn: false, }, [SystemType.AWAITING_REPLY]: { name: "Awaiting Reply", @@ -41,6 +44,7 @@ const ruleConfig: Record< runOnThreads: true, categoryAction: "label", tooltipText: "", + shouldLearn: false, }, [SystemType.ACTIONED]: { name: "Actioned", @@ -49,6 +53,7 @@ const ruleConfig: Record< runOnThreads: true, categoryAction: "label", tooltipText: "", + shouldLearn: false, }, [SystemType.NEWSLETTER]: { name: "Newsletter", @@ -59,6 +64,7 @@ const ruleConfig: Record< categoryAction: "label", categoryActionMicrosoft: "move_folder", tooltipText: "Newsletters, blogs, and publications", + shouldLearn: true, }, [SystemType.MARKETING]: { name: "Marketing", @@ -69,6 +75,7 @@ const ruleConfig: Record< categoryAction: "label_archive", categoryActionMicrosoft: "move_folder", tooltipText: "Promotional emails about sales and offers", + shouldLearn: true, }, [SystemType.CALENDAR]: { name: "Calendar", @@ -78,6 +85,7 @@ const ruleConfig: Record< runOnThreads: false, categoryAction: "label", tooltipText: "Events, appointments, and reminders", + shouldLearn: true, }, [SystemType.RECEIPT]: { name: "Receipt", @@ -88,6 +96,7 @@ const ruleConfig: Record< categoryAction: "label", categoryActionMicrosoft: "move_folder", tooltipText: "Invoices, receipts, and payments", + shouldLearn: true, }, [SystemType.NOTIFICATION]: { name: "Notification", @@ -97,6 +106,7 @@ const ruleConfig: Record< categoryAction: "label", categoryActionMicrosoft: "move_folder", tooltipText: "Alerts, status updates, and system messages", + shouldLearn: true, }, [SystemType.COLD_EMAIL]: { name: "Cold Email", @@ -107,6 +117,7 @@ const ruleConfig: Record< categoryActionMicrosoft: "move_folder", tooltipText: "Unsolicited sales pitches and cold emails. We'll never block someone that's emailed you before", + shouldLearn: true, }, }; @@ -124,6 +135,10 @@ export function getRuleLabel(systemType: SystemType) { return getRuleConfig(systemType).label; } +export function shouldLearnFromLabelRemoval(systemType: SystemType): boolean { + return getRuleConfig(systemType).shouldLearn; +} + export function getCategoryAction(systemType: SystemType, provider: string) { const config = getRuleConfig(systemType); diff --git a/apps/web/utils/rule/learned-patterns.test.ts b/apps/web/utils/rule/learned-patterns.test.ts new file mode 100644 index 0000000000..89cee9ffdc --- /dev/null +++ b/apps/web/utils/rule/learned-patterns.test.ts @@ -0,0 +1,253 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { saveLearnedPattern, saveLearnedPatterns } from "./learned-patterns"; +import prisma from "@/utils/__mocks__/prisma"; +import { GroupItemType, GroupItemSource } from "@/generated/prisma/enums"; +import { isDuplicateError } from "@/utils/prisma-helpers"; + +vi.mock("server-only", () => ({})); +vi.mock("@/utils/prisma"); + +vi.mock("@/utils/prisma-helpers", () => ({ + isDuplicateError: vi.fn(), +})); + +const mockLogger = { + error: vi.fn(), + warn: vi.fn(), + info: vi.fn(), +} as any; + +describe("saveLearnedPattern", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should return early if rule not found", async () => { + vi.mocked(prisma.rule.findUnique).mockResolvedValue(null); + + await saveLearnedPattern({ + emailAccountId: "email-account-id", + from: "test@example.com", + ruleId: "nonexistent-rule", + logger: mockLogger, + }); + + expect(mockLogger.error).toHaveBeenCalledWith("Rule not found", { + ruleId: "nonexistent-rule", + }); + expect(prisma.groupItem.upsert).not.toHaveBeenCalled(); + }); + + it("should use existing groupId when rule has one", async () => { + const existingGroupId = "existing-group-id"; + vi.mocked(prisma.rule.findUnique).mockResolvedValue({ + id: "rule-id", + name: "Test Rule", + groupId: existingGroupId, + } as any); + vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any); + + await saveLearnedPattern({ + emailAccountId: "email-account-id", + from: "test@example.com", + ruleId: "rule-id", + logger: mockLogger, + }); + + expect(prisma.group.create).not.toHaveBeenCalled(); + expect(prisma.groupItem.upsert).toHaveBeenCalledWith({ + where: { + groupId_type_value: { + groupId: existingGroupId, + type: GroupItemType.FROM, + value: "test@example.com", + }, + }, + update: expect.objectContaining({ exclude: false }), + create: expect.objectContaining({ + groupId: existingGroupId, + type: GroupItemType.FROM, + value: "test@example.com", + }), + }); + }); + + it("should create a new group when rule has no groupId", async () => { + const newGroupId = "new-group-id"; + vi.mocked(prisma.rule.findUnique).mockResolvedValue({ + id: "rule-id", + name: "Test Rule", + groupId: null, + } as any); + vi.mocked(prisma.group.create).mockResolvedValue({ + id: newGroupId, + } as any); + vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any); + + await saveLearnedPattern({ + emailAccountId: "email-account-id", + from: "test@example.com", + ruleId: "rule-id", + logger: mockLogger, + }); + + expect(prisma.group.create).toHaveBeenCalledWith({ + data: { + emailAccountId: "email-account-id", + name: "Test Rule", + rule: { connect: { id: "rule-id" } }, + }, + }); + expect(prisma.groupItem.upsert).toHaveBeenCalledWith( + expect.objectContaining({ + where: { + groupId_type_value: { + groupId: newGroupId, + type: GroupItemType.FROM, + value: "test@example.com", + }, + }, + }), + ); + }); + + it("should save pattern with exclude: true", async () => { + vi.mocked(prisma.rule.findUnique).mockResolvedValue({ + id: "rule-id", + name: "Test Rule", + groupId: "group-id", + } as any); + vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any); + + await saveLearnedPattern({ + emailAccountId: "email-account-id", + from: "excluded@example.com", + ruleId: "rule-id", + exclude: true, + logger: mockLogger, + reason: "User excluded", + source: GroupItemSource.USER, + }); + + expect(prisma.groupItem.upsert).toHaveBeenCalledWith({ + where: { + groupId_type_value: { + groupId: "group-id", + type: GroupItemType.FROM, + value: "excluded@example.com", + }, + }, + update: { + exclude: true, + reason: "User excluded", + threadId: undefined, + messageId: undefined, + source: GroupItemSource.USER, + }, + create: { + groupId: "group-id", + type: GroupItemType.FROM, + value: "excluded@example.com", + exclude: true, + reason: "User excluded", + threadId: undefined, + messageId: undefined, + source: GroupItemSource.USER, + }, + }); + }); + + it("should handle duplicate group creation by finding existing group", async () => { + const existingGroupId = "existing-group-id"; + vi.mocked(prisma.rule.findUnique) + .mockResolvedValueOnce({ + id: "rule-id", + name: "Test Rule", + groupId: null, + } as any) + .mockResolvedValueOnce({ + groupId: null, + } as any); + + const duplicateError = new Error("Duplicate key"); + vi.mocked(prisma.group.create).mockRejectedValue(duplicateError); + vi.mocked(isDuplicateError).mockReturnValue(true); + vi.mocked(prisma.group.findUnique).mockResolvedValue({ + id: existingGroupId, + } as any); + vi.mocked(prisma.rule.update).mockResolvedValue({} as any); + vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any); + + await saveLearnedPattern({ + emailAccountId: "email-account-id", + from: "test@example.com", + ruleId: "rule-id", + logger: mockLogger, + }); + + expect(prisma.group.findUnique).toHaveBeenCalledWith({ + where: { + name_emailAccountId: { + name: "Test Rule", + emailAccountId: "email-account-id", + }, + }, + select: { id: true }, + }); + expect(prisma.groupItem.upsert).toHaveBeenCalledWith( + expect.objectContaining({ + where: { + groupId_type_value: { + groupId: existingGroupId, + type: GroupItemType.FROM, + value: "test@example.com", + }, + }, + }), + ); + }); +}); + +describe("saveLearnedPatterns", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should return error if rule not found", async () => { + vi.mocked(prisma.rule.findUnique).mockResolvedValue(null); + + const result = await saveLearnedPatterns({ + emailAccountId: "email-account-id", + ruleName: "Nonexistent Rule", + patterns: [{ type: GroupItemType.FROM, value: "test@example.com" }], + logger: mockLogger, + }); + + expect(result).toEqual({ error: "Rule not found" }); + expect(mockLogger.error).toHaveBeenCalledWith("Rule not found", { + emailAccountId: "email-account-id", + ruleName: "Nonexistent Rule", + }); + }); + + it("should save multiple patterns successfully", async () => { + vi.mocked(prisma.rule.findUnique).mockResolvedValue({ + id: "rule-id", + groupId: "group-id", + } as any); + vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any); + + const result = await saveLearnedPatterns({ + emailAccountId: "email-account-id", + ruleName: "Test Rule", + patterns: [ + { type: GroupItemType.FROM, value: "sender1@example.com" }, + { type: GroupItemType.SUBJECT, value: "Newsletter", exclude: true }, + ], + logger: mockLogger, + }); + + expect(result).toEqual({ success: true }); + expect(prisma.groupItem.upsert).toHaveBeenCalledTimes(2); + }); +}); diff --git a/apps/web/utils/rule/learned-patterns.ts b/apps/web/utils/rule/learned-patterns.ts index 0517c21164..7742aa2b8e 100644 --- a/apps/web/utils/rule/learned-patterns.ts +++ b/apps/web/utils/rule/learned-patterns.ts @@ -1,6 +1,6 @@ import prisma from "@/utils/prisma"; import type { Logger } from "@/utils/logger"; -import { GroupItemType } from "@/generated/prisma/enums"; +import { GroupItemType, type GroupItemSource } from "@/generated/prisma/enums"; import { isDuplicateError } from "@/utils/prisma-helpers"; /** @@ -11,43 +11,41 @@ import { isDuplicateError } from "@/utils/prisma-helpers"; export async function saveLearnedPattern({ emailAccountId, from, - ruleName, + ruleId, + exclude = false, logger, + reason, + threadId, + messageId, + source, }: { emailAccountId: string; from: string; - ruleName: string; + ruleId: string; + exclude?: boolean; logger: Logger; + reason?: string | null; + threadId?: string | null; + messageId?: string | null; + source?: GroupItemSource | null; }) { const rule = await prisma.rule.findUnique({ - where: { - name_emailAccountId: { - name: ruleName, - emailAccountId, - }, - }, - select: { id: true, groupId: true }, + where: { id: ruleId, emailAccountId }, + select: { id: true, name: true, groupId: true }, }); if (!rule) { - logger.error("Rule not found", { emailAccountId, ruleName }); + logger.error("Rule not found", { ruleId }); return; } - let groupId = rule.groupId; - - if (!groupId) { - // Create a new group for this rule if one doesn't exist - const newGroup = await prisma.group.create({ - data: { - emailAccountId, - name: ruleName, - rule: { connect: { id: rule.id } }, - }, - }); - - groupId = newGroup.id; - } + const groupId = await getOrCreateGroupForRule({ + emailAccountId, + ruleId: rule.id, + ruleName: rule.name, + existingGroupId: rule.groupId, + logger, + }); await prisma.groupItem.upsert({ where: { @@ -57,11 +55,22 @@ export async function saveLearnedPattern({ value: from, }, }, - update: {}, + update: { + exclude, + reason, + threadId, + messageId, + source, + }, create: { groupId, type: GroupItemType.FROM, value: from, + exclude, + reason, + threadId, + messageId, + source, }, }); } @@ -100,43 +109,24 @@ export async function saveLearnedPatterns({ return { error: "Rule not found" }; } - let groupId = rule.groupId; - - if (!groupId) { - try { - const newGroup = await prisma.group.create({ - data: { - emailAccountId, - name: ruleName, - rule: { connect: { id: rule.id } }, - }, - }); - - groupId = newGroup.id; - } catch (error) { - if (isDuplicateError(error)) { - logger.error("Group already exists", { emailAccountId, ruleName }); - const newGroup2 = await prisma.group.create({ - data: { - emailAccountId, - name: `${ruleName} (${new Date().toISOString()})`, - rule: { connect: { id: rule.id } }, - }, - }); - groupId = newGroup2.id; - } else { - logger.error("Error creating learned patterns group", { error }); - return { error: "Error creating learned patterns group" }; - } - } + let groupId: string; + try { + groupId = await getOrCreateGroupForRule({ + emailAccountId, + ruleId: rule.id, + ruleName: ruleName, + existingGroupId: rule.groupId, + logger, + }); + } catch (error) { + logger.error("Error creating learned patterns group", { error }); + return { error: "Error creating learned patterns group" }; } const errors: string[] = []; // Process all patterns in a single function for (const pattern of patterns) { - // Store pattern with the exclude flag properly set in the database - // This maps directly to the new exclude field in the GroupItem model try { await prisma.groupItem.upsert({ where: { @@ -175,3 +165,65 @@ export async function saveLearnedPatterns({ return { success: true }; } + +async function getOrCreateGroupForRule({ + emailAccountId, + ruleId, + ruleName, + existingGroupId, + logger, +}: { + emailAccountId: string; + ruleId: string; + ruleName: string; + existingGroupId: string | null; + logger: Logger; +}): Promise { + if (existingGroupId) return existingGroupId; + + // Try to create the group + try { + const newGroup = await prisma.group.create({ + data: { + emailAccountId, + name: ruleName, + rule: { connect: { id: ruleId } }, + }, + }); + return newGroup.id; + } catch (error) { + if (!isDuplicateError(error)) throw error; + } + + // Handle duplicate: check if rule was concurrently updated with a group + const updatedRule = await prisma.rule.findUnique({ + where: { id: ruleId }, + select: { groupId: true }, + }); + if (updatedRule?.groupId) return updatedRule.groupId; + + // Check if a group with the same name exists + const existingGroup = await prisma.group.findUnique({ + where: { name_emailAccountId: { name: ruleName, emailAccountId } }, + select: { id: true }, + }); + + if (existingGroup) { + // Attempt to link it (ignore failures from concurrent updates) + await prisma.rule + .update({ where: { id: ruleId }, data: { groupId: existingGroup.id } }) + .catch((error) => { + logger.warn( + "Failed to link existing group to rule (likely concurrent update)", + { + ruleId, + groupId: existingGroup.id, + error, + }, + ); + }); + return existingGroup.id; + } + + throw new Error(`Failed to create or find group for rule: ${ruleName}`); +}