diff --git a/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx b/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx
index fad0132c15..bcd41bf243 100644
--- a/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx
+++ b/apps/web/app/(app)/[emailAccountId]/assistant/group/ViewLearnedPatterns.tsx
@@ -96,21 +96,6 @@ function ViewGroupInner({ groupId }: { groupId: string }) {
Add pattern
-
- {!!group?.items?.length && (
-
- )}
)}
diff --git a/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx b/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx
index 32e66884e3..51bd127434 100644
--- a/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx
+++ b/apps/web/app/(app)/[emailAccountId]/assistant/settings/WritingStyleSetting.tsx
@@ -115,8 +115,11 @@ function WritingStyleDialog({
registerProps={register("writingStyle")}
error={errors.writingStyle}
placeholder="Typical Length: 2-3 sentences
+
Formality: Informal but professional
+
Common Greeting: Hey,
+
Notable Traits:
- Uses contractions frequently
- Concise and direct responses
diff --git a/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx b/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx
index 8a08080c30..3ccae88683 100644
--- a/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx
+++ b/apps/web/app/(app)/[emailAccountId]/cold-email-blocker/ColdEmailList.tsx
@@ -20,6 +20,7 @@ import { AlertBasic } from "@/components/Alert";
import { Button } from "@/components/ui/button";
import { useSearchParams } from "next/navigation";
import { markNotColdEmailAction } from "@/utils/actions/cold-email";
+import { toggleRuleAction } from "@/utils/actions/rule";
import { Checkbox } from "@/components/Checkbox";
import { useToggleSelect } from "@/hooks/useToggleSelect";
import { ViewEmailButton } from "@/components/ViewEmailButton";
@@ -27,9 +28,9 @@ import { EmailMessageCellWithData } from "@/components/EmailMessageCell";
import { EnableFeatureCard } from "@/components/EnableFeatureCard";
import { toastError, toastSuccess } from "@/components/Toast";
import { useAccount } from "@/providers/EmailAccountProvider";
-import { prefixPath } from "@/utils/path";
import { useRules } from "@/hooks/useRules";
import { isColdEmailBlockerEnabled } from "@/utils/cold-email/cold-email-blocker-enabled";
+import { SystemType } from "@/generated/prisma/enums";
export function ColdEmailList() {
const searchParams = useSearchParams();
@@ -187,18 +188,36 @@ function Row({
function NoColdEmails() {
const { emailAccountId } = useAccount();
- const { data: rules } = useRules();
+ const { data: rules, mutate: mutateRules } = useRules();
+
+ const { executeAsync: enableColdEmailBlocker } = useAction(
+ toggleRuleAction.bind(null, emailAccountId),
+ {
+ onSuccess: () => {
+ toastSuccess({ description: "Cold email blocker enabled!" });
+ mutateRules();
+ },
+ onError: () => {
+ toastError({ description: "Error enabling cold email blocker" });
+ },
+ },
+ );
if (!isColdEmailBlockerEnabled(rules || [])) {
return (
{
+ await enableColdEmailBlocker({
+ systemType: SystemType.COLD_EMAIL,
+ enabled: true,
+ });
+ }}
hideBorder
/>
diff --git a/apps/web/app/api/ai/analyze-sender-pattern/route.ts b/apps/web/app/api/ai/analyze-sender-pattern/route.ts
index 23d53d3134..dc9bfb6082 100644
--- a/apps/web/app/api/ai/analyze-sender-pattern/route.ts
+++ b/apps/web/app/api/ai/analyze-sender-pattern/route.ts
@@ -10,6 +10,7 @@ import { isValidInternalApiKey } from "@/utils/internal-api";
import { extractEmailAddress } from "@/utils/email";
import { getEmailForLLM } from "@/utils/get-email-from-message";
import { saveLearnedPattern } from "@/utils/rule/learned-patterns";
+import { GroupItemSource } from "@/generated/prisma/enums";
import { checkSenderRuleHistory } from "@/utils/rule/check-sender-rule-history";
import { createEmailProvider } from "@/utils/email/provider";
import type { EmailProvider } from "@/utils/email/types";
@@ -177,12 +178,24 @@ async function process({
if (patternResult?.matchedRule) {
// Verify the AI matched the same rule as the historical data
if (patternResult.matchedRule === senderHistory.consistentRuleName) {
- await saveLearnedPattern({
- emailAccountId,
- from,
- ruleName: patternResult.matchedRule,
- logger,
- });
+ const matchedRule = emailAccount.rules.find(
+ (rule) => rule.name === patternResult.matchedRule,
+ );
+
+ if (matchedRule) {
+ await saveLearnedPattern({
+ emailAccountId,
+ from,
+ ruleId: matchedRule.id,
+ logger,
+ source: GroupItemSource.AI,
+ });
+ } else {
+ logger.error("Matched rule not found in email account rules", {
+ ruleName: patternResult.matchedRule,
+ availableRules: emailAccount.rules.map((r) => r.name),
+ });
+ }
} else {
logger.warn("AI suggested different rule than historical data", {
aiRule: patternResult.matchedRule,
diff --git a/apps/web/app/api/google/webhook/process-label-removed-event.test.ts b/apps/web/app/api/google/webhook/process-label-removed-event.test.ts
index 39c895222e..10c95dbef2 100644
--- a/apps/web/app/api/google/webhook/process-label-removed-event.test.ts
+++ b/apps/web/app/api/google/webhook/process-label-removed-event.test.ts
@@ -1,20 +1,27 @@
import { vi, describe, it, expect, beforeEach } from "vitest";
-import { ColdEmailStatus } from "@/generated/prisma/enums";
import { HistoryEventType } from "./types";
import { handleLabelRemovedEvent } from "./process-label-removed-event";
import type { gmail_v1 } from "@googleapis/gmail";
-import { saveLearnedPatterns } from "@/utils/rule/learned-patterns";
-import prisma from "@/utils/__mocks__/prisma";
+import { saveLearnedPattern } from "@/utils/rule/learned-patterns";
import { createScopedLogger } from "@/utils/logger";
+import { GroupItemSource, SystemType } from "@/generated/prisma/enums";
+import prisma from "@/utils/prisma";
const logger = createScopedLogger("test");
vi.mock("server-only", () => ({}));
// Mock dependencies
-vi.mock("@/utils/prisma");
+vi.mock("@/utils/prisma", () => ({
+ default: {
+ rule: {
+ findFirst: vi.fn(),
+ },
+ },
+}));
+
vi.mock("@/utils/rule/learned-patterns", () => ({
- saveLearnedPatterns: vi.fn().mockResolvedValue(undefined),
+ saveLearnedPattern: vi.fn().mockResolvedValue(undefined),
}));
vi.mock("@/utils/gmail/label", () => ({
@@ -89,88 +96,42 @@ describe("process-label-removed-event", () => {
};
describe("handleLabelRemovedEvent", () => {
- it("should process Cold Email label removal and update ColdEmail status", async () => {
- prisma.coldEmail.upsert.mockResolvedValue({} as any);
+ it("should process Cold Email label removal and call saveLearnedPattern with exclude: true", async () => {
+ vi.mocked(prisma.rule.findFirst).mockResolvedValue({
+ id: "rule-123",
+ systemType: SystemType.COLD_EMAIL,
+ } as any);
const historyItem = createLabelRemovedHistoryItem();
- console.log("Test data:", JSON.stringify(historyItem.item, null, 2));
-
- try {
- await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
- } catch (error) {
- console.error("Function error:", error);
- throw error;
- }
-
- expect(prisma.coldEmail.upsert).toHaveBeenCalledWith({
- where: {
- emailAccountId_fromEmail: {
- emailAccountId: "email-account-id",
- fromEmail: "sender@example.com",
- },
- },
- update: {
- status: ColdEmailStatus.USER_REJECTED_COLD,
- },
- create: {
- status: ColdEmailStatus.USER_REJECTED_COLD,
- fromEmail: "sender@example.com",
- emailAccountId: "email-account-id",
- messageId: "123",
- threadId: "thread-123",
- },
- });
- });
-
- it("should skip learning when Newsletter label is removed (only Cold Email is supported)", async () => {
- const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
- "label-2",
- ]);
-
await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
- expect(saveLearnedPatterns).not.toHaveBeenCalled();
- });
-
- it("should skip learning when To Reply label is removed (only Cold Email is supported)", async () => {
- const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
- "label-4",
- ]);
-
- await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
-
- expect(saveLearnedPatterns).not.toHaveBeenCalled();
- });
-
- it("should skip learning when no executed rule exists (only Cold Email is supported)", async () => {
- const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
- "label-2",
- ]);
-
- await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
-
- expect(saveLearnedPatterns).not.toHaveBeenCalled();
+ expect(saveLearnedPattern).toHaveBeenCalledWith({
+ emailAccountId: "email-account-id",
+ from: "sender@example.com",
+ ruleId: "rule-123",
+ exclude: true,
+ logger: expect.anything(),
+ messageId: "123",
+ threadId: "thread-123",
+ reason: "Label removed",
+ source: GroupItemSource.LABEL_REMOVED,
+ });
});
- it("should skip learning when no matching LABEL action is found (only Cold Email is supported)", async () => {
- const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
- "label-2",
- ]);
+ it("should skip learning when To Reply label is removed (not a learnable rule)", async () => {
+ vi.mocked(prisma.rule.findFirst).mockResolvedValue({
+ id: "rule-456",
+ systemType: SystemType.TO_REPLY,
+ } as any);
- await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
-
- expect(saveLearnedPatterns).not.toHaveBeenCalled();
- });
-
- it("should handle multiple label removals in a single event (only Cold Email is supported)", async () => {
const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
- "label-3",
+ "label-4",
]);
await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
- expect(saveLearnedPatterns).not.toHaveBeenCalled();
+ expect(saveLearnedPattern).not.toHaveBeenCalled();
});
it("should skip processing when only system labels are removed", async () => {
@@ -183,7 +144,7 @@ describe("process-label-removed-event", () => {
// Should not try to fetch the message when only system labels removed
expect(mockProvider.getMessage).not.toHaveBeenCalled();
- expect(prisma.coldEmail.upsert).not.toHaveBeenCalled();
+ expect(saveLearnedPattern).not.toHaveBeenCalled();
});
it("should skip processing when DRAFT label is removed (prevents 404 errors)", async () => {
@@ -194,9 +155,8 @@ describe("process-label-removed-event", () => {
await handleLabelRemovedEvent(historyItem, defaultOptions, logger);
- // Should not try to fetch the message (which would fail with 404)
expect(mockProvider.getMessage).not.toHaveBeenCalled();
- expect(prisma.coldEmail.upsert).not.toHaveBeenCalled();
+ expect(saveLearnedPattern).not.toHaveBeenCalled();
});
it("should skip processing when messageId is missing", async () => {
@@ -207,7 +167,7 @@ describe("process-label-removed-event", () => {
await handleLabelRemovedEvent(historyItem, defaultOptions, logger);
- expect(prisma.coldEmail.upsert).not.toHaveBeenCalled();
+ expect(saveLearnedPattern).not.toHaveBeenCalled();
});
it("should skip processing when threadId is missing", async () => {
@@ -218,7 +178,46 @@ describe("process-label-removed-event", () => {
await handleLabelRemovedEvent(historyItem, defaultOptions, logger);
- expect(prisma.coldEmail.upsert).not.toHaveBeenCalled();
+ expect(saveLearnedPattern).not.toHaveBeenCalled();
+ });
+
+ it("should handle multiple label removals in a single event", async () => {
+ vi.mocked(prisma.rule.findFirst)
+ .mockResolvedValueOnce({
+ id: "rule-1",
+ systemType: SystemType.COLD_EMAIL,
+ } as any)
+ .mockResolvedValueOnce({
+ id: "rule-2",
+ systemType: SystemType.NEWSLETTER,
+ } as any);
+
+ const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
+ "label-1",
+ "label-2",
+ ]);
+
+ await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
+
+ expect(saveLearnedPattern).toHaveBeenCalledTimes(2);
+ expect(saveLearnedPattern).toHaveBeenCalledWith(
+ expect.objectContaining({ ruleId: "rule-1" }),
+ );
+ expect(saveLearnedPattern).toHaveBeenCalledWith(
+ expect.objectContaining({ ruleId: "rule-2" }),
+ );
+ });
+
+ it("should skip learning when no rule is found for the removed label", async () => {
+ vi.mocked(prisma.rule.findFirst).mockResolvedValue(null);
+
+ const historyItem = createLabelRemovedHistoryItem("123", "thread-123", [
+ "unknown-label",
+ ]);
+
+ await handleLabelRemovedEvent(historyItem.item, defaultOptions, logger);
+
+ expect(saveLearnedPattern).not.toHaveBeenCalled();
});
});
});
diff --git a/apps/web/app/api/google/webhook/process-label-removed-event.ts b/apps/web/app/api/google/webhook/process-label-removed-event.ts
index 3e2bbb5b5e..2ae4c9611a 100644
--- a/apps/web/app/api/google/webhook/process-label-removed-event.ts
+++ b/apps/web/app/api/google/webhook/process-label-removed-event.ts
@@ -1,12 +1,13 @@
import type { gmail_v1 } from "@googleapis/gmail";
-import prisma from "@/utils/prisma";
-import { ColdEmailStatus, SystemType } from "@/generated/prisma/enums";
+import { GroupItemSource, ActionType } from "@/generated/prisma/enums";
+import { saveLearnedPattern } from "@/utils/rule/learned-patterns";
import { extractEmailAddress } from "@/utils/email";
import type { EmailAccountWithAI } from "@/utils/llms/types";
import type { EmailProvider } from "@/utils/email/types";
import { GmailLabel } from "@/utils/gmail/label";
-import { getRuleLabel } from "@/utils/rule/consts";
+import { shouldLearnFromLabelRemoval } from "@/utils/rule/consts";
import type { Logger } from "@/utils/logger";
+import prisma from "@/utils/prisma";
import {
isGmailRateLimitExceededError,
isGmailQuotaExceededError,
@@ -115,22 +116,10 @@ export async function handleLabelRemovedEvent(
return;
}
- const labels = await provider.getLabels();
-
for (const labelId of removedLabelIds) {
- const label = labels?.find((l) => l.id === labelId);
- const labelName = label?.name;
-
- if (!labelName) {
- logger.info("Skipping label removal - missing label name", {
- labelId,
- });
- continue;
- }
-
try {
await learnFromRemovedLabel({
- labelName,
+ labelId,
sender,
messageId,
threadId,
@@ -140,7 +129,7 @@ export async function handleLabelRemovedEvent(
} catch (error) {
logger.error("Error learning from label removal", {
error,
- labelName,
+ labelId,
removedLabelIds,
});
}
@@ -148,21 +137,21 @@ export async function handleLabelRemovedEvent(
}
async function learnFromRemovedLabel({
- labelName,
+ labelId,
sender,
messageId,
threadId,
emailAccountId,
logger,
}: {
- labelName: string;
+ labelId: string;
sender: string | null;
messageId: string;
threadId: string;
emailAccountId: string;
logger: Logger;
}) {
- logger = logger.with({ labelName, sender });
+ logger = logger.with({ labelId, sender });
// Can't learn patterns without knowing who to exclude
if (!sender) {
@@ -170,28 +159,41 @@ async function learnFromRemovedLabel({
return;
}
- if (labelName === getRuleLabel(SystemType.COLD_EMAIL)) {
- logger.info("Processing Cold Email label removal");
-
- await prisma.coldEmail.upsert({
- where: {
- emailAccountId_fromEmail: {
- emailAccountId,
- fromEmail: sender,
+ // Find rule with matching label action
+ const rule = await prisma.rule.findFirst({
+ where: {
+ emailAccountId,
+ systemType: { not: null },
+ actions: {
+ some: {
+ labelId: labelId,
+ type: ActionType.LABEL,
},
},
- update: {
- status: ColdEmailStatus.USER_REJECTED_COLD,
- },
- create: {
- status: ColdEmailStatus.USER_REJECTED_COLD,
- fromEmail: sender,
- emailAccountId,
- messageId,
- threadId,
- },
- });
+ },
+ select: { id: true, systemType: true },
+ });
+ if (!rule?.systemType || !shouldLearnFromLabelRemoval(rule.systemType)) {
+ logger.info("Label removal does not match a learnable system rule", {
+ systemType: rule?.systemType,
+ });
return;
}
+
+ logger.info("Processing label removal for learning", {
+ systemType: rule.systemType,
+ });
+
+ await saveLearnedPattern({
+ emailAccountId,
+ from: sender,
+ ruleId: rule.id,
+ exclude: true,
+ logger,
+ messageId,
+ threadId,
+ reason: "Label removed",
+ source: GroupItemSource.LABEL_REMOVED,
+ });
}
diff --git a/apps/web/app/api/resend/summary/route.ts b/apps/web/app/api/resend/summary/route.ts
index 8679286cd9..7fe1918494 100644
--- a/apps/web/app/api/resend/summary/route.ts
+++ b/apps/web/app/api/resend/summary/route.ts
@@ -7,7 +7,7 @@ import { env } from "@/env";
import { hasCronSecret } from "@/utils/cron";
import { captureException } from "@/utils/error";
import prisma from "@/utils/prisma";
-import { ThreadTrackerType } from "@/generated/prisma/enums";
+import { SystemType, ThreadTrackerType } from "@/generated/prisma/enums";
import type { Logger } from "@/utils/logger";
import { getMessagesBatch } from "@/utils/gmail/message";
import { decodeSnippet } from "@/utils/gmail/decode";
@@ -110,7 +110,6 @@ async function sendEmail({
where: { id: emailAccountId },
select: {
email: true,
- coldEmails: { where: { createdAt: { gt: cutOffDate } } },
account: {
select: {
access_token: true,
@@ -124,84 +123,78 @@ async function sendEmail({
return { success: false };
}
- if (emailAccount) {
- logger.info("Email account found");
- } else {
- logger.error("Email account not found or cutoff date is in the future", {
- cutOffDate,
- });
- return { success: true };
- }
-
- // Get counts and recent threads for each type
- const [
- counts,
- needsReply,
- awaitingReply,
- // needsAction
- ] = await Promise.all([
- // total count
- // NOTE: should really be distinct by threadId. this will cause a mismatch in some cases
- prisma.threadTracker.groupBy({
- by: ["type"],
- where: {
- emailAccountId,
- resolved: false,
- },
- _count: true,
- }),
- // needs reply
- prisma.threadTracker.findMany({
- where: {
+ const coldEmailRule = await prisma.rule.findUnique({
+ where: {
+ emailAccountId_systemType: {
emailAccountId,
- type: ThreadTrackerType.NEEDS_REPLY,
- resolved: false,
+ systemType: SystemType.COLD_EMAIL,
},
- orderBy: { sentAt: "desc" },
- take: 20,
- distinct: ["threadId"],
- }),
- // awaiting reply
- prisma.threadTracker.findMany({
- where: {
- emailAccountId,
- type: ThreadTrackerType.AWAITING,
- resolved: false,
- // only show emails that are more than 3 days overdue
- sentAt: { lt: subHours(new Date(), 24 * 3) },
- },
- orderBy: { sentAt: "desc" },
- take: 20,
- distinct: ["threadId"],
- }),
- // needs action - currently not used
- // prisma.threadTracker.findMany({
- // where: {
- // userId: user.id,
- // type: ThreadTrackerType.NEEDS_ACTION,
- // resolved: false,
- // },
- // orderBy: { sentAt: "desc" },
- // take: 20,
- // distinct: ["threadId"],
- // }),
- ]);
+ },
+ select: { id: true },
+ });
+
+ // Get counts and recent threads for each type
+ const [counts, needsReply, awaitingReply, coldExecutedRules] =
+ await Promise.all([
+ // total count
+ // NOTE: should really be distinct by threadId. this will cause a mismatch in some cases
+ prisma.threadTracker.groupBy({
+ by: ["type"],
+ where: {
+ emailAccountId,
+ resolved: false,
+ },
+ _count: true,
+ }),
+ // needs reply
+ prisma.threadTracker.findMany({
+ where: {
+ emailAccountId,
+ type: ThreadTrackerType.NEEDS_REPLY,
+ resolved: false,
+ },
+ orderBy: { sentAt: "desc" },
+ take: 20,
+ distinct: ["threadId"],
+ }),
+ // awaiting reply
+ prisma.threadTracker.findMany({
+ where: {
+ emailAccountId,
+ type: ThreadTrackerType.AWAITING,
+ resolved: false,
+ // only show emails that are more than 3 days overdue
+ sentAt: { lt: subHours(new Date(), 24 * 3) },
+ },
+ orderBy: { sentAt: "desc" },
+ take: 20,
+ distinct: ["threadId"],
+ }),
+ // cold emails
+ coldEmailRule
+ ? prisma.executedRule.findMany({
+ where: {
+ ruleId: coldEmailRule.id,
+ automated: true,
+ createdAt: { gt: cutOffDate },
+ },
+ select: {
+ messageId: true,
+ createdAt: true,
+ },
+ })
+ : Promise.resolve([]),
+ ]);
const typeCounts = Object.fromEntries(
counts.map((count) => [count.type, count._count]),
);
- const coldEmailers = emailAccount.coldEmails.map((e) => ({
- from: e.fromEmail,
- subject: "",
- sentAt: e.createdAt,
- }));
-
// get messages
const messageIds = [
...needsReply.map((m) => m.messageId),
...awaitingReply.map((m) => m.messageId),
- // ...needsAction.map((m) => m.messageId),
+ ...coldExecutedRules.map((r) => r.messageId),
];
logger.info("Getting messages", {
@@ -237,14 +230,14 @@ async function sendEmail({
};
});
- // const recentNeedsAction = needsAction.map((t) => {
- // const message = messageMap[t.messageId];
- // return {
- // from: message?.headers.from || "Unknown",
- // subject: decodeSnippet(message?.snippet) || "",
- // sentAt: t.sentAt,
- // };
- // });
+ const coldEmailers = coldExecutedRules.map((r) => {
+ const message = messageMap[r.messageId];
+ return {
+ from: message?.headers.from || "Unknown",
+ subject: decodeSnippet(message?.snippet) || "",
+ sentAt: r.createdAt,
+ };
+ });
const shouldSendEmail = !!(
coldEmailers.length ||
@@ -281,7 +274,6 @@ async function sendEmail({
needsActionCount: typeCounts[ThreadTrackerType.NEEDS_ACTION],
needsReply: recentNeedsReply,
awaitingReply: recentAwaitingReply,
- // needsAction: recentNeedsAction,
unsubscribeToken: token,
},
});
diff --git a/apps/web/app/api/user/cold-email/route.ts b/apps/web/app/api/user/cold-email/route.ts
index d7047e9711..3f517fd500 100644
--- a/apps/web/app/api/user/cold-email/route.ts
+++ b/apps/web/app/api/user/cold-email/route.ts
@@ -1,7 +1,11 @@
import { NextResponse } from "next/server";
import prisma from "@/utils/prisma";
import { withEmailAccount } from "@/utils/middleware";
-import { ColdEmailStatus } from "@/generated/prisma/enums";
+import {
+ ColdEmailStatus,
+ GroupItemType,
+ SystemType,
+} from "@/generated/prisma/enums";
const LIMIT = 50;
@@ -14,30 +18,54 @@ async function getColdEmails(
}: { emailAccountId: string; status: ColdEmailStatus },
page: number,
) {
+ const coldEmailRule = await prisma.rule.findUnique({
+ where: {
+ emailAccountId_systemType: {
+ emailAccountId,
+ systemType: SystemType.COLD_EMAIL,
+ },
+ },
+ select: { id: true, groupId: true },
+ });
+
+ if (!coldEmailRule?.groupId) {
+ return { coldEmails: [], totalPages: 0 };
+ }
+
const where = {
- emailAccountId,
- status,
+ groupId: coldEmailRule.groupId,
+ type: GroupItemType.FROM,
+ exclude: status === ColdEmailStatus.USER_REJECTED_COLD,
};
- const [coldEmails, count] = await Promise.all([
- prisma.coldEmail.findMany({
+ const [groupItems, count] = await Promise.all([
+ prisma.groupItem.findMany({
where,
take: LIMIT,
skip: (page - 1) * LIMIT,
orderBy: { createdAt: "desc" },
select: {
id: true,
- fromEmail: true,
- status: true,
+ value: true,
createdAt: true,
reason: true,
threadId: true,
messageId: true,
},
}),
- prisma.coldEmail.count({ where }),
+ prisma.groupItem.count({ where }),
]);
+ const coldEmails = groupItems.map((item) => ({
+ id: item.id,
+ fromEmail: item.value,
+ status: status,
+ createdAt: item.createdAt,
+ reason: item.reason,
+ threadId: item.threadId,
+ messageId: item.messageId,
+ }));
+
return { coldEmails, totalPages: Math.ceil(count / LIMIT) };
}
diff --git a/apps/web/components/EmailMessageCell.tsx b/apps/web/components/EmailMessageCell.tsx
index 84aebc7e66..f3b7ac29b9 100644
--- a/apps/web/components/EmailMessageCell.tsx
+++ b/apps/web/components/EmailMessageCell.tsx
@@ -144,7 +144,8 @@ export function EmailMessageCellWithData({
}) {
const { data, isLoading, error } = useThread({ id: threadId });
- const firstMessage = data?.thread.messages?.[0];
+ const firstMessage = data?.thread?.messages?.[0];
+ const emailNotFound = !isLoading && !error && !firstMessage;
return (
);
}
diff --git a/apps/web/prisma/migrations/20260103000000_migrate_cold_emails_to_group_items/migration.sql b/apps/web/prisma/migrations/20260103000000_migrate_cold_emails_to_group_items/migration.sql
new file mode 100644
index 0000000000..e977ea830d
--- /dev/null
+++ b/apps/web/prisma/migrations/20260103000000_migrate_cold_emails_to_group_items/migration.sql
@@ -0,0 +1,97 @@
+-- CreateEnum
+CREATE TYPE "GroupItemSource" AS ENUM ('AI', 'USER');
+
+-- AlterTable
+ALTER TABLE "GroupItem" ADD COLUMN "messageId" TEXT,
+ADD COLUMN "reason" TEXT,
+ADD COLUMN "source" "GroupItemSource",
+ADD COLUMN "threadId" TEXT;
+
+-- Migrate ColdEmail data to GroupItem
+-- This migration moves historical cold email data from the deprecated ColdEmail table
+-- to the unified GroupItem table (learned patterns system)
+
+-- Step 1: Create Groups for Cold Email rules that don't have one yet
+-- If a group with the same name and emailAccountId already exists, reuse it; otherwise create a new group
+DO $$
+DECLARE
+ rule_record RECORD;
+ new_group_id TEXT;
+ group_name TEXT;
+BEGIN
+ FOR rule_record IN
+ SELECT r.id, r.name, r."emailAccountId"
+ FROM "Rule" r
+ WHERE r."systemType" = 'COLD_EMAIL'
+ AND r."groupId" IS NULL
+ LOOP
+ new_group_id := gen_random_uuid()::TEXT;
+ group_name := rule_record.name;
+
+ -- Check if a group with this name already exists
+ IF EXISTS (
+ SELECT 1 FROM "Group" g
+ WHERE g.name = group_name AND g."emailAccountId" = rule_record."emailAccountId"
+ ) THEN
+ -- Use existing group if it exists
+ UPDATE "Rule"
+ SET "groupId" = (
+ SELECT id FROM "Group" g
+ WHERE g.name = group_name AND g."emailAccountId" = rule_record."emailAccountId"
+ LIMIT 1
+ )
+ WHERE id = rule_record.id;
+ ELSE
+ -- Create new group
+ INSERT INTO "Group" (id, "createdAt", "updatedAt", name, "emailAccountId")
+ VALUES (new_group_id, NOW(), NOW(), group_name, rule_record."emailAccountId");
+
+ UPDATE "Rule" SET "groupId" = new_group_id WHERE id = rule_record.id;
+ END IF;
+ END LOOP;
+END $$;
+
+-- Step 2: Migrate ColdEmail records to GroupItem
+INSERT INTO "GroupItem" (
+ id,
+ "createdAt",
+ "updatedAt",
+ "groupId",
+ type,
+ value,
+ exclude,
+ reason,
+ "threadId",
+ "messageId",
+ source
+)
+SELECT
+ gen_random_uuid() as id,
+ ce."createdAt",
+ ce."updatedAt",
+ r."groupId",
+ 'FROM'::"GroupItemType" as type,
+ ce."fromEmail" as value,
+ CASE
+ WHEN ce.status = 'USER_REJECTED_COLD' THEN true
+ ELSE false
+ END as exclude,
+ ce.reason,
+ ce."threadId",
+ ce."messageId",
+ CASE
+ WHEN ce.status = 'USER_REJECTED_COLD' THEN 'USER'::"GroupItemSource"
+ ELSE 'AI'::"GroupItemSource"
+ END as source
+FROM "ColdEmail" ce
+JOIN "Rule" r ON r."emailAccountId" = ce."emailAccountId" AND r."systemType" = 'COLD_EMAIL'
+WHERE r."groupId" IS NOT NULL
+ AND ce."fromEmail" IS NOT NULL
+ -- Avoid duplicates: only insert if this pattern doesn't already exist
+ AND NOT EXISTS (
+ SELECT 1 FROM "GroupItem" gi
+ WHERE gi."groupId" = r."groupId"
+ AND gi.type = 'FROM'
+ AND gi.value = ce."fromEmail"
+ );
+
diff --git a/apps/web/prisma/migrations/20260104000000_add_label_removed_to_group_item_source/migration.sql b/apps/web/prisma/migrations/20260104000000_add_label_removed_to_group_item_source/migration.sql
new file mode 100644
index 0000000000..e9a11409f5
--- /dev/null
+++ b/apps/web/prisma/migrations/20260104000000_add_label_removed_to_group_item_source/migration.sql
@@ -0,0 +1,3 @@
+-- AlterEnum
+ALTER TYPE "GroupItemSource" ADD VALUE 'LABEL_REMOVED';
+
diff --git a/apps/web/prisma/schema.prisma b/apps/web/prisma/schema.prisma
index 3cd1d88e14..846623fb59 100644
--- a/apps/web/prisma/schema.prisma
+++ b/apps/web/prisma/schema.prisma
@@ -156,7 +156,7 @@ model EmailAccount {
rules Rule[]
executedRules ExecutedRule[]
newsletters Newsletter[]
- coldEmails ColdEmail[]
+ coldEmails ColdEmail[] // @deprecated - kept for backward compatibility during migration
groups Group[]
categories Category[]
threadTrackers ThreadTracker[]
@@ -277,12 +277,12 @@ model DigestItem {
digest Digest @relation(fields: [digestId], references: [id], onDelete: Cascade)
actionId String?
action ExecutedAction? @relation(fields: [actionId], references: [id], onDelete: Cascade)
- coldEmailId String?
- coldEmail ColdEmail? @relation(fields: [coldEmailId], references: [id])
+ coldEmailId String? // @deprecated
+ coldEmail ColdEmail? @relation(fields: [coldEmailId], references: [id]) // @deprecated
@@unique([digestId, threadId, messageId])
@@index([actionId])
- @@index([coldEmailId])
+ @@index([coldEmailId]) // @deprecated
}
model Schedule {
@@ -620,6 +620,12 @@ model GroupItem {
value String // eg "@gmail.com", "matt@gmail.com", "Receipt from"
exclude Boolean @default(false) // Whether this pattern should be excluded rather than included
+ // Optional context for why/how this pattern was learned.
+ reason String?
+ threadId String?
+ messageId String?
+ source GroupItemSource? // provides value for UI/audit.
+
@@unique([groupId, type, value])
}
@@ -664,6 +670,9 @@ model Newsletter {
@@index([categoryId])
}
+// @deprecated - ColdEmail data is being migrated to GroupItem (learned patterns).
+// This model is kept for backward compatibility during the migration period.
+// Once all users have run the migration, this model can be deleted.
model ColdEmail {
id String @id @default(cuid())
createdAt DateTime @default(now())
@@ -1189,3 +1198,9 @@ enum MeetingBriefingStatus {
FAILED
SKIPPED
}
+
+enum GroupItemSource {
+ AI
+ USER
+ LABEL_REMOVED
+}
diff --git a/apps/web/utils/actions/cold-email.ts b/apps/web/utils/actions/cold-email.ts
index ef633f29dc..6f877dadde 100644
--- a/apps/web/utils/actions/cold-email.ts
+++ b/apps/web/utils/actions/cold-email.ts
@@ -1,7 +1,7 @@
"use server";
import prisma from "@/utils/prisma";
-import { ColdEmailStatus, SystemType } from "@/generated/prisma/enums";
+import { GroupItemSource } from "@/generated/prisma/enums";
import { emailToContent } from "@/utils/mail";
import { isColdEmail } from "@/utils/cold-email/is-cold-email";
import {
@@ -13,8 +13,8 @@ import { SafeError } from "@/utils/error";
import { createEmailProvider } from "@/utils/email/provider";
import type { EmailProvider } from "@/utils/email/types";
import { getColdEmailRule } from "@/utils/cold-email/cold-email-rule";
-import { getRuleLabel } from "@/utils/rule/consts";
import { internalDateToDate } from "@/utils/date";
+import { saveLearnedPattern } from "@/utils/rule/learned-patterns";
export const markNotColdEmailAction = actionClient
.metadata({ name: "markNotColdEmail" })
@@ -24,74 +24,52 @@ export const markNotColdEmailAction = actionClient
ctx: { emailAccountId, provider, logger },
parsedInput: { sender },
}) => {
- const emailProvider = await createEmailProvider({
- emailAccountId,
- provider,
- logger,
- });
+ const [emailProvider, coldEmailRule] = await Promise.all([
+ createEmailProvider({
+ emailAccountId,
+ provider,
+ logger,
+ }),
+ getColdEmailRule(emailAccountId),
+ ]);
+
+ if (!coldEmailRule) {
+ throw new SafeError("Cold email rule not found");
+ }
await Promise.all([
- prisma.coldEmail.update({
- where: {
- emailAccountId_fromEmail: {
- emailAccountId,
- fromEmail: sender,
- },
- },
- data: {
- status: ColdEmailStatus.USER_REJECTED_COLD,
- },
+ // Mark as excluded so AI doesn't match it again
+ saveLearnedPattern({
+ emailAccountId,
+ from: sender,
+ ruleId: coldEmailRule.id,
+ exclude: true,
+ logger,
+ source: GroupItemSource.USER,
}),
- removeColdEmailLabelFromSender(emailAccountId, emailProvider, sender),
+ removeColdEmailLabelFromSender(emailProvider, sender, coldEmailRule),
]);
},
);
-/**
- * Helper function to get threads from a specific sender using the email provider
- */
-async function getThreadsFromSender(
- emailProvider: EmailProvider,
- sender: string,
- labelId?: string,
-): Promise<{ id: string }[]> {
- const { threads } = await emailProvider.getThreadsWithQuery({
- query: {
- fromEmail: sender,
- labelId,
- },
- maxResults: 100,
- });
-
- return threads.map((thread) => ({ id: thread.id }));
-}
-
async function removeColdEmailLabelFromSender(
- emailAccountId: string,
emailProvider: EmailProvider,
sender: string,
+ coldEmailRule: { actions: { labelId: string | null }[] },
) {
- // 1. find cold email label
- // 2. find emails from sender
- // 3. remove cold email label from emails
-
- const coldEmailRule = await getColdEmailRule(emailAccountId);
- if (!coldEmailRule) return;
+ const labelIds = coldEmailRule.actions
+ .map((action) => action.labelId)
+ .filter((id): id is string => Boolean(id));
- const labels = await emailProvider.getLabels();
+ if (labelIds.length === 0) return;
- // NOTE: this doesn't work completely if the user set 2 labels:
- const label =
- labels.find((label) => label.id === coldEmailRule.actions?.[0]?.labelId) ||
- labels.find((label) => label.name === getRuleLabel(SystemType.COLD_EMAIL));
-
- if (!label?.id) return;
-
- const threads = await getThreadsFromSender(emailProvider, sender, label.id);
+ const { threads } = await emailProvider.getThreadsWithQuery({
+ query: { fromEmail: sender },
+ maxResults: 100,
+ });
for (const thread of threads) {
- if (!thread.id) continue;
- await emailProvider.removeThreadLabel(thread.id, label.id);
+ await emailProvider.removeThreadLabels(thread.id, labelIds);
}
}
diff --git a/apps/web/utils/ai/choose-rule/match-rules.test.ts b/apps/web/utils/ai/choose-rule/match-rules.test.ts
index bd7076ce7e..e8c6bdb2f2 100644
--- a/apps/web/utils/ai/choose-rule/match-rules.test.ts
+++ b/apps/web/utils/ai/choose-rule/match-rules.test.ts
@@ -1653,34 +1653,121 @@ describe("filterToReplyPreset", () => {
});
function getRule(overrides: Partial = {}): RuleWithActions {
+ const {
+ id = "r123",
+ createdAt = new Date(),
+ updatedAt = new Date(),
+ name = "Rule Name",
+ enabled = true,
+ automate = true,
+ runOnThreads = true,
+ emailAccountId = "emailAccountId",
+ conditionalOperator = LogicalOperator.AND,
+ instructions = null,
+ groupId = null,
+ from = null,
+ to = null,
+ subject = null,
+ body = null,
+ categoryFilterType = null,
+ systemType = null,
+ promptText = null,
+ actions = [],
+ } = overrides;
+
return {
- id: "r123",
- userId: "userId",
- runOnThreads: true,
- conditionalOperator: LogicalOperator.AND,
- type: null,
- systemType: null,
- ...overrides,
- } as RuleWithActions;
+ id,
+ createdAt,
+ updatedAt,
+ name,
+ enabled,
+ automate,
+ runOnThreads,
+ emailAccountId,
+ conditionalOperator,
+ instructions,
+ groupId,
+ from,
+ to,
+ subject,
+ body,
+ categoryFilterType,
+ systemType,
+ promptText,
+ actions,
+ };
}
function getHeaders(
overrides: Partial = {},
): ParsedMessageHeaders {
+ const {
+ subject = "Subject",
+ from = "from@example.com",
+ to = "to@example.com",
+ cc,
+ bcc,
+ date = new Date().toISOString(),
+ "message-id": messageId,
+ "reply-to": replyTo,
+ "in-reply-to": inReplyTo,
+ references,
+ "list-unsubscribe": listUnsubscribe,
+ } = overrides;
+
return {
- ...overrides,
- } as ParsedMessageHeaders;
+ subject,
+ from,
+ to,
+ cc,
+ bcc,
+ date,
+ "message-id": messageId,
+ "reply-to": replyTo,
+ "in-reply-to": inReplyTo,
+ references,
+ "list-unsubscribe": listUnsubscribe,
+ };
}
function getMessage(overrides: Partial = {}): ParsedMessage {
- const message = {
- id: "m1",
- threadId: "m1",
- headers: getHeaders(),
- ...overrides,
- };
+ const {
+ id = "m1",
+ threadId = "m1",
+ labelIds = [],
+ snippet = "snippet",
+ historyId = "h1",
+ attachments = [],
+ inline = [],
+ headers = getHeaders(),
+ textPlain = "textPlain",
+ textHtml = "textHtml",
+ subject = "subject",
+ date = new Date().toISOString(),
+ conversationIndex = null,
+ internalDate = null,
+ bodyContentType,
+ rawRecipients,
+ } = overrides;
- return message as ParsedMessage;
+ return {
+ id,
+ threadId,
+ labelIds,
+ snippet,
+ historyId,
+ attachments,
+ inline,
+ headers,
+ textPlain,
+ textHtml,
+ subject,
+ date,
+ conversationIndex,
+ internalDate,
+ bodyContentType,
+ rawRecipients,
+ };
}
function getGroup(
@@ -1688,29 +1775,56 @@ function getGroup(
Prisma.GroupGetPayload<{ include: { items: true; rule: true } }>
> = {},
): Prisma.GroupGetPayload<{ include: { items: true; rule: true } }> {
+ const {
+ id = "group1",
+ name = "group",
+ createdAt = new Date(),
+ updatedAt = new Date(),
+ emailAccountId = "emailAccountId",
+ prompt = null,
+ items = [],
+ rule = null,
+ } = overrides;
+
return {
- id: "group1",
- name: "group",
- createdAt: new Date(),
- updatedAt: new Date(),
- emailAccountId: "emailAccountId",
- prompt: null,
- items: [],
- rule: null,
- ...overrides,
+ id,
+ name,
+ createdAt,
+ updatedAt,
+ emailAccountId,
+ prompt,
+ items,
+ rule,
};
}
function getGroupItem(overrides: Partial = {}): GroupItem {
+ const {
+ id = "groupItem1",
+ createdAt = new Date(),
+ updatedAt = new Date(),
+ groupId = "groupId",
+ type = GroupItemType.FROM,
+ value = "test@example.com",
+ exclude = false,
+ reason = null,
+ threadId = null,
+ messageId = null,
+ source = null,
+ } = overrides;
+
return {
- id: "groupItem1",
- createdAt: new Date(),
- updatedAt: new Date(),
- groupId: "groupId",
- type: GroupItemType.FROM,
- value: "test@example.com",
- exclude: false,
- ...overrides,
+ id,
+ createdAt,
+ updatedAt,
+ groupId,
+ type,
+ value,
+ exclude,
+ reason,
+ threadId,
+ messageId,
+ source,
};
}
diff --git a/apps/web/utils/ai/choose-rule/match-rules.ts b/apps/web/utils/ai/choose-rule/match-rules.ts
index 52dcf1f493..36c902478a 100644
--- a/apps/web/utils/ai/choose-rule/match-rules.ts
+++ b/apps/web/utils/ai/choose-rule/match-rules.ts
@@ -84,7 +84,7 @@ export async function findMatchingRules({
matchReasons: [{ type: ConditionType.AI }],
},
],
- reasoning: coldEmailResult.reason,
+ reasoning: coldEmailResult.aiReason || coldEmailResult.reason,
};
}
}
diff --git a/apps/web/utils/ai/choose-rule/run-rules.test.ts b/apps/web/utils/ai/choose-rule/run-rules.test.ts
index 67e30b5ec3..0cbc73573e 100644
--- a/apps/web/utils/ai/choose-rule/run-rules.test.ts
+++ b/apps/web/utils/ai/choose-rule/run-rules.test.ts
@@ -38,8 +38,9 @@ vi.mock("@/utils/ai/choose-rule/execute", () => ({
vi.mock("@/utils/reply-tracker/label-helpers", () => ({
removeConflictingThreadStatusLabels: vi.fn(),
}));
-vi.mock("@/utils/cold-email/is-cold-email", () => ({
- saveColdEmail: vi.fn(),
+vi.mock("@/utils/rule/learned-patterns", () => ({
+ saveLearnedPattern: vi.fn(),
+ saveLearnedPatterns: vi.fn(),
}));
vi.mock("@/utils/scheduled-actions/scheduler", () => ({
scheduleDelayedActions: vi.fn(),
diff --git a/apps/web/utils/ai/choose-rule/run-rules.ts b/apps/web/utils/ai/choose-rule/run-rules.ts
index e3dba35139..48aed7fd79 100644
--- a/apps/web/utils/ai/choose-rule/run-rules.ts
+++ b/apps/web/utils/ai/choose-rule/run-rules.ts
@@ -4,6 +4,7 @@ import type { EmailAccountWithAI } from "@/utils/llms/types";
import {
ActionType,
ExecutedRuleStatus,
+ GroupItemSource,
SystemType,
} from "@/generated/prisma/enums";
import type { Rule } from "@/generated/prisma/client";
@@ -34,7 +35,7 @@ import {
updateThreadTrackers,
} from "@/utils/reply-tracker/handle-conversation-status";
import { removeConflictingThreadStatusLabels } from "@/utils/reply-tracker/label-helpers";
-import { saveColdEmail } from "@/utils/cold-email/is-cold-email";
+import { saveLearnedPattern } from "@/utils/rule/learned-patterns";
import { internalDateToDate } from "@/utils/date";
import { ConditionType } from "@/utils/config";
import type { Logger } from "@/utils/logger";
@@ -340,14 +341,17 @@ async function executeMatchedRule(
});
if (rule.systemType === SystemType.COLD_EMAIL) {
- await saveColdEmail({
- email: {
- id: message.id,
- threadId: message.threadId,
- from: message.headers.from,
- },
- emailAccount,
- aiReason: reason ?? null,
+ const from =
+ extractEmailAddress(message.headers.from) || message.headers.from;
+ await saveLearnedPattern({
+ emailAccountId: emailAccount.id,
+ from,
+ ruleId: rule.id,
+ logger,
+ reason,
+ messageId: message.id,
+ threadId: message.threadId,
+ source: GroupItemSource.AI,
});
}
diff --git a/apps/web/utils/ai/choose-rule/types.ts b/apps/web/utils/ai/choose-rule/types.ts
index 9be6ad3e11..da3aad477e 100644
--- a/apps/web/utils/ai/choose-rule/types.ts
+++ b/apps/web/utils/ai/choose-rule/types.ts
@@ -41,7 +41,7 @@ export type MatchingRuleResult = {
/**
* Serializable version of MatchReason for database storage
*/
-export type SerializedMatchReason =
+type SerializedMatchReason =
| { type: "STATIC" }
| {
type: "LEARNED_PATTERN";
diff --git a/apps/web/utils/cold-email/cold-email-rule.ts b/apps/web/utils/cold-email/cold-email-rule.ts
index c228396c8c..501168017b 100644
--- a/apps/web/utils/cold-email/cold-email-rule.ts
+++ b/apps/web/utils/cold-email/cold-email-rule.ts
@@ -17,6 +17,7 @@ export async function getColdEmailRule(emailAccountId: string) {
id: true,
enabled: true,
instructions: true,
+ groupId: true,
actions: {
select: {
type: true,
diff --git a/apps/web/utils/cold-email/is-cold-email.test.ts b/apps/web/utils/cold-email/is-cold-email.test.ts
index a7d3931304..95dc0a67f3 100644
--- a/apps/web/utils/cold-email/is-cold-email.test.ts
+++ b/apps/web/utils/cold-email/is-cold-email.test.ts
@@ -1,20 +1,16 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
-import { isColdEmail, saveColdEmail } from "./is-cold-email";
+import { isColdEmail } from "./is-cold-email";
import { getEmailAccount } from "@/__tests__/helpers";
import type { EmailForLLM } from "@/utils/types";
-import { ColdEmailStatus } from "@/generated/prisma/enums";
-import prisma from "@/utils/prisma";
+import { GroupItemType } from "@/generated/prisma/enums";
+import prisma from "@/utils/__mocks__/prisma";
import { extractEmailAddress } from "@/utils/email";
vi.mock("server-only", () => ({}));
+vi.mock("@/utils/prisma");
-vi.mock("@/utils/prisma", () => ({
- default: {
- coldEmail: {
- findUnique: vi.fn(),
- upsert: vi.fn(),
- },
- },
+vi.mock("./cold-email-rule", () => ({
+ getColdEmailRule: vi.fn(),
}));
vi.mock("@/utils/email", async () => {
@@ -41,34 +37,15 @@ describe("isColdEmail", () => {
it("should recognize a known cold email sender even when from field format differs", async () => {
const emailAccount = getEmailAccount({ id: "test-account-id" });
const normalizedEmail = "cold.sender@example.com";
+ const groupId = "test-group-id";
- // First, simulate saving a cold email with normalized email address
- // This is what saveColdEmail does - it extracts just the email address
- vi.mocked(prisma.coldEmail.upsert).mockResolvedValue({
- id: "cold-email-id",
- emailAccountId: emailAccount.id,
- fromEmail: normalizedEmail,
- status: ColdEmailStatus.AI_LABELED_COLD,
- reason: "Test reason",
- messageId: "msg1",
- threadId: "thread1",
- createdAt: new Date(),
- updatedAt: new Date(),
- });
+ // Mock groupItem lookup
+ vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({
+ id: "group-item-id",
+ exclude: false,
+ } as any);
- await saveColdEmail({
- email: {
- from: normalizedEmail,
- id: "msg1",
- threadId: "thread1",
- },
- emailAccount,
- aiReason: "Test reason",
- });
-
- // Now simulate a second email from the same sender but with a different format
- // This is the bug scenario: the from field has a display name
- const secondEmail: EmailForLLM = {
+ const email: EmailForLLM = {
id: "msg2",
from: `"Cold Sender" <${normalizedEmail}>`,
to: emailAccount.email,
@@ -77,93 +54,91 @@ describe("isColdEmail", () => {
date: new Date(),
};
- // Mock Prisma to return the cold email record when queried with normalized email
- vi.mocked(prisma.coldEmail.findUnique).mockResolvedValue({
- id: "cold-email-id",
- emailAccountId: emailAccount.id,
- fromEmail: normalizedEmail,
- status: ColdEmailStatus.AI_LABELED_COLD,
- reason: "Test reason",
- messageId: "msg1",
- threadId: "thread1",
- createdAt: new Date(),
- updatedAt: new Date(),
- });
-
const result = await isColdEmail({
- email: secondEmail,
+ email,
emailAccount,
provider: mockProvider as never,
- coldEmailRule: null,
+ coldEmailRule: { instructions: "test instructions", groupId },
});
- // This test should pass after the fix - the sender should be recognized as cold
expect(result.isColdEmail).toBe(true);
expect(result.reason).toBe("ai-already-labeled");
- // Verify that findUnique was called with the normalized email address
- expect(prisma.coldEmail.findUnique).toHaveBeenCalledWith({
+ // Verify that findFirst was called with the normalized email address
+ expect(prisma.groupItem.findFirst).toHaveBeenCalledWith({
where: {
- emailAccountId_fromEmail: {
- emailAccountId: emailAccount.id,
- fromEmail: normalizedEmail,
- },
- status: ColdEmailStatus.AI_LABELED_COLD,
+ groupId,
+ type: GroupItemType.FROM,
+ value: normalizedEmail,
+ },
+ select: { exclude: true },
+ });
+ });
+
+ it("should return excluded when sender is explicitly excluded from cold email blocker", async () => {
+ const emailAccount = getEmailAccount({ id: "test-account-id" });
+ const normalizedEmail = "excluded.sender@example.com";
+ const groupId = "test-group-id";
+
+ // Mock groupItem lookup with exclude: true
+ vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({
+ id: "group-item-id",
+ exclude: true,
+ } as any);
+
+ const email: EmailForLLM = {
+ id: "msg-excluded",
+ from: `"Excluded Sender" <${normalizedEmail}>`,
+ to: emailAccount.email,
+ subject: "Not a cold email",
+ content: "This sender was explicitly excluded",
+ date: new Date(),
+ };
+
+ const result = await isColdEmail({
+ email,
+ emailAccount,
+ provider: mockProvider as never,
+ coldEmailRule: { instructions: "test instructions", groupId },
+ });
+
+ expect(result.isColdEmail).toBe(false);
+ expect(result.reason).toBe("excluded");
+
+ expect(prisma.groupItem.findFirst).toHaveBeenCalledWith({
+ where: {
+ groupId,
+ type: GroupItemType.FROM,
+ value: normalizedEmail,
},
- select: { id: true },
+ select: { exclude: true },
});
});
it("should handle various email formats consistently", async () => {
const emailAccount = getEmailAccount({ id: "test-account-id" });
const normalizedEmail = "sender@example.com";
+ const groupId = "test-group-id";
+
+ vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({
+ id: "group-item-id",
+ exclude: false,
+ } as any);
- // Test different from field formats that should all resolve to the same normalized email
const emailFormats = [
normalizedEmail,
`<${normalizedEmail}>`,
`"Display Name" <${normalizedEmail}>`,
`Display Name <${normalizedEmail}>`,
- ` ${normalizedEmail} `, // with spaces
+ ` ${normalizedEmail} `,
];
- // Mock Prisma to return cold email for normalized email
- vi.mocked(prisma.coldEmail.findUnique).mockImplementation(
- (args) =>
- new Promise((resolve) => {
- const where = args?.where as
- | {
- emailAccountId_fromEmail: {
- emailAccountId: string;
- fromEmail: string;
- };
- status: ColdEmailStatus;
- }
- | undefined;
-
- if (
- where?.emailAccountId_fromEmail.fromEmail === normalizedEmail &&
- where.status === ColdEmailStatus.AI_LABELED_COLD
- ) {
- resolve({
- id: "cold-email-id",
- emailAccountId: emailAccount.id,
- fromEmail: normalizedEmail,
- status: ColdEmailStatus.AI_LABELED_COLD,
- reason: "Test reason",
- messageId: "msg1",
- threadId: "thread1",
- createdAt: new Date(),
- updatedAt: new Date(),
- } as never);
- } else {
- resolve(null as never);
- }
- }) as never,
- );
-
for (const fromFormat of emailFormats) {
vi.clearAllMocks();
+ vi.mocked(prisma.groupItem.findFirst).mockResolvedValue({
+ id: "group-item-id",
+ exclude: false,
+ } as any);
const email: EmailForLLM = {
id: "msg-test",
@@ -178,23 +153,21 @@ describe("isColdEmail", () => {
email,
emailAccount,
provider: mockProvider as never,
- coldEmailRule: null,
+ coldEmailRule: { instructions: "test instructions", groupId },
});
expect(result.isColdEmail).toBe(true);
expect(result.reason).toBe("ai-already-labeled");
- // Verify extractEmailAddress was used to normalize
- const expectedNormalized = extractEmailAddress(fromFormat);
- expect(prisma.coldEmail.findUnique).toHaveBeenCalledWith({
+ const expectedNormalized =
+ extractEmailAddress(fromFormat) || fromFormat.trim();
+ expect(prisma.groupItem.findFirst).toHaveBeenCalledWith({
where: {
- emailAccountId_fromEmail: {
- emailAccountId: emailAccount.id,
- fromEmail: expectedNormalized,
- },
- status: ColdEmailStatus.AI_LABELED_COLD,
+ groupId,
+ type: GroupItemType.FROM,
+ value: expectedNormalized,
},
- select: { id: true },
+ select: { exclude: true },
});
}
});
diff --git a/apps/web/utils/cold-email/is-cold-email.ts b/apps/web/utils/cold-email/is-cold-email.ts
index 5988da8edc..89e976ed9f 100644
--- a/apps/web/utils/cold-email/is-cold-email.ts
+++ b/apps/web/utils/cold-email/is-cold-email.ts
@@ -1,7 +1,7 @@
import { z } from "zod";
import type { EmailAccountWithAI } from "@/utils/llms/types";
-import type { ColdEmail, Rule } from "@/generated/prisma/client";
-import { ColdEmailStatus } from "@/generated/prisma/enums";
+import type { Rule } from "@/generated/prisma/client";
+import { GroupItemType } from "@/generated/prisma/enums";
import prisma from "@/utils/prisma";
import { DEFAULT_COLD_EMAIL_PROMPT } from "@/utils/cold-email/prompt";
import { stringifyEmail } from "@/utils/stringify-email";
@@ -14,7 +14,11 @@ import { extractEmailAddress } from "@/utils/email";
export const COLD_EMAIL_FOLDER_NAME = "Cold Emails";
-type ColdEmailBlockerReason = "hasPreviousEmail" | "ai" | "ai-already-labeled";
+type ColdEmailBlockerReason =
+ | "hasPreviousEmail"
+ | "ai"
+ | "ai-already-labeled"
+ | "excluded";
export async function isColdEmail({
email,
@@ -27,7 +31,7 @@ export async function isColdEmail({
emailAccount: EmailAccountWithAI;
provider: EmailProvider;
modelType?: ModelType;
- coldEmailRule: Pick | null;
+ coldEmailRule: Pick | null;
}): Promise<{
isColdEmail: boolean;
reason: ColdEmailBlockerReason;
@@ -43,16 +47,31 @@ export async function isColdEmail({
logger.info("Checking is cold email");
// Check if we marked it as a cold email already
- const isColdEmailer = await isKnownColdEmailSender({
- from: email.from,
- emailAccountId: emailAccount.id,
- });
+ const groupId = coldEmailRule?.groupId;
+ let patternMatch: { exclude: boolean } | null = null;
+
+ if (groupId) {
+ const normalizedFrom = extractEmailAddress(email.from) || email.from;
+ patternMatch = await prisma.groupItem.findFirst({
+ where: {
+ groupId,
+ type: GroupItemType.FROM,
+ value: normalizedFrom,
+ },
+ select: { exclude: true },
+ });
+ }
- if (isColdEmailer) {
- logger.info("Known cold email sender", {
+ if (patternMatch && !patternMatch.exclude) {
+ logger.info("Known cold email sender", { from: email.from });
+ return { isColdEmail: true, reason: "ai-already-labeled" };
+ }
+
+ if (patternMatch?.exclude) {
+ logger.info("Sender explicitly excluded from cold email blocker", {
from: email.from,
});
- return { isColdEmail: true, reason: "ai-already-labeled" };
+ return { isColdEmail: false, reason: "excluded" };
}
const hasPreviousEmail =
@@ -88,28 +107,6 @@ export async function isColdEmail({
};
}
-async function isKnownColdEmailSender({
- from,
- emailAccountId,
-}: {
- from: string;
- emailAccountId: string;
-}) {
- const normalizedFrom = extractEmailAddress(from) || from;
-
- const coldEmail = await prisma.coldEmail.findUnique({
- where: {
- emailAccountId_fromEmail: {
- emailAccountId,
- fromEmail: normalizedFrom,
- },
- status: ColdEmailStatus.AI_LABELED_COLD,
- },
- select: { id: true },
- });
- return !!coldEmail;
-}
-
async function aiIsColdEmail(
email: EmailForLLM,
emailAccount: EmailAccountWithAI,
@@ -161,33 +158,3 @@ ${stringifyEmail(email, 500)}
return response.object;
}
-
-export async function saveColdEmail({
- email,
- emailAccount,
- aiReason,
-}: {
- email: { from: string; id: string; threadId: string };
- emailAccount: EmailAccountWithAI;
- aiReason: string | null;
-}): Promise {
- const from = extractEmailAddress(email.from) || email.from;
-
- return await prisma.coldEmail.upsert({
- where: {
- emailAccountId_fromEmail: {
- emailAccountId: emailAccount.id,
- fromEmail: from,
- },
- },
- update: { status: ColdEmailStatus.AI_LABELED_COLD },
- create: {
- status: ColdEmailStatus.AI_LABELED_COLD,
- fromEmail: from,
- emailAccountId: emailAccount.id,
- reason: aiReason,
- messageId: email.id,
- threadId: email.threadId,
- },
- });
-}
diff --git a/apps/web/utils/rule/consts.ts b/apps/web/utils/rule/consts.ts
index 7167a696e7..50658248d6 100644
--- a/apps/web/utils/rule/consts.ts
+++ b/apps/web/utils/rule/consts.ts
@@ -14,6 +14,7 @@ const ruleConfig: Record<
categoryAction: "label" | "label_archive" | "move_folder";
categoryActionMicrosoft?: "move_folder";
tooltipText: string;
+ shouldLearn: boolean;
}
> = {
[SystemType.TO_REPLY]: {
@@ -25,6 +26,7 @@ const ruleConfig: Record<
categoryAction: "label",
tooltipText:
"Emails you need to reply to and those where you're awaiting a reply. The label will update automatically as the conversation progresses",
+ shouldLearn: false,
},
[SystemType.FYI]: {
name: "FYI",
@@ -33,6 +35,7 @@ const ruleConfig: Record<
runOnThreads: true,
categoryAction: "label",
tooltipText: "",
+ shouldLearn: false,
},
[SystemType.AWAITING_REPLY]: {
name: "Awaiting Reply",
@@ -41,6 +44,7 @@ const ruleConfig: Record<
runOnThreads: true,
categoryAction: "label",
tooltipText: "",
+ shouldLearn: false,
},
[SystemType.ACTIONED]: {
name: "Actioned",
@@ -49,6 +53,7 @@ const ruleConfig: Record<
runOnThreads: true,
categoryAction: "label",
tooltipText: "",
+ shouldLearn: false,
},
[SystemType.NEWSLETTER]: {
name: "Newsletter",
@@ -59,6 +64,7 @@ const ruleConfig: Record<
categoryAction: "label",
categoryActionMicrosoft: "move_folder",
tooltipText: "Newsletters, blogs, and publications",
+ shouldLearn: true,
},
[SystemType.MARKETING]: {
name: "Marketing",
@@ -69,6 +75,7 @@ const ruleConfig: Record<
categoryAction: "label_archive",
categoryActionMicrosoft: "move_folder",
tooltipText: "Promotional emails about sales and offers",
+ shouldLearn: true,
},
[SystemType.CALENDAR]: {
name: "Calendar",
@@ -78,6 +85,7 @@ const ruleConfig: Record<
runOnThreads: false,
categoryAction: "label",
tooltipText: "Events, appointments, and reminders",
+ shouldLearn: true,
},
[SystemType.RECEIPT]: {
name: "Receipt",
@@ -88,6 +96,7 @@ const ruleConfig: Record<
categoryAction: "label",
categoryActionMicrosoft: "move_folder",
tooltipText: "Invoices, receipts, and payments",
+ shouldLearn: true,
},
[SystemType.NOTIFICATION]: {
name: "Notification",
@@ -97,6 +106,7 @@ const ruleConfig: Record<
categoryAction: "label",
categoryActionMicrosoft: "move_folder",
tooltipText: "Alerts, status updates, and system messages",
+ shouldLearn: true,
},
[SystemType.COLD_EMAIL]: {
name: "Cold Email",
@@ -107,6 +117,7 @@ const ruleConfig: Record<
categoryActionMicrosoft: "move_folder",
tooltipText:
"Unsolicited sales pitches and cold emails. We'll never block someone that's emailed you before",
+ shouldLearn: true,
},
};
@@ -124,6 +135,10 @@ export function getRuleLabel(systemType: SystemType) {
return getRuleConfig(systemType).label;
}
+export function shouldLearnFromLabelRemoval(systemType: SystemType): boolean {
+ return getRuleConfig(systemType).shouldLearn;
+}
+
export function getCategoryAction(systemType: SystemType, provider: string) {
const config = getRuleConfig(systemType);
diff --git a/apps/web/utils/rule/learned-patterns.test.ts b/apps/web/utils/rule/learned-patterns.test.ts
new file mode 100644
index 0000000000..89cee9ffdc
--- /dev/null
+++ b/apps/web/utils/rule/learned-patterns.test.ts
@@ -0,0 +1,253 @@
+import { describe, it, expect, vi, beforeEach } from "vitest";
+import { saveLearnedPattern, saveLearnedPatterns } from "./learned-patterns";
+import prisma from "@/utils/__mocks__/prisma";
+import { GroupItemType, GroupItemSource } from "@/generated/prisma/enums";
+import { isDuplicateError } from "@/utils/prisma-helpers";
+
+vi.mock("server-only", () => ({}));
+vi.mock("@/utils/prisma");
+
+vi.mock("@/utils/prisma-helpers", () => ({
+ isDuplicateError: vi.fn(),
+}));
+
+const mockLogger = {
+ error: vi.fn(),
+ warn: vi.fn(),
+ info: vi.fn(),
+} as any;
+
+describe("saveLearnedPattern", () => {
+ beforeEach(() => {
+ vi.clearAllMocks();
+ });
+
+ it("should return early if rule not found", async () => {
+ vi.mocked(prisma.rule.findUnique).mockResolvedValue(null);
+
+ await saveLearnedPattern({
+ emailAccountId: "email-account-id",
+ from: "test@example.com",
+ ruleId: "nonexistent-rule",
+ logger: mockLogger,
+ });
+
+ expect(mockLogger.error).toHaveBeenCalledWith("Rule not found", {
+ ruleId: "nonexistent-rule",
+ });
+ expect(prisma.groupItem.upsert).not.toHaveBeenCalled();
+ });
+
+ it("should use existing groupId when rule has one", async () => {
+ const existingGroupId = "existing-group-id";
+ vi.mocked(prisma.rule.findUnique).mockResolvedValue({
+ id: "rule-id",
+ name: "Test Rule",
+ groupId: existingGroupId,
+ } as any);
+ vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any);
+
+ await saveLearnedPattern({
+ emailAccountId: "email-account-id",
+ from: "test@example.com",
+ ruleId: "rule-id",
+ logger: mockLogger,
+ });
+
+ expect(prisma.group.create).not.toHaveBeenCalled();
+ expect(prisma.groupItem.upsert).toHaveBeenCalledWith({
+ where: {
+ groupId_type_value: {
+ groupId: existingGroupId,
+ type: GroupItemType.FROM,
+ value: "test@example.com",
+ },
+ },
+ update: expect.objectContaining({ exclude: false }),
+ create: expect.objectContaining({
+ groupId: existingGroupId,
+ type: GroupItemType.FROM,
+ value: "test@example.com",
+ }),
+ });
+ });
+
+ it("should create a new group when rule has no groupId", async () => {
+ const newGroupId = "new-group-id";
+ vi.mocked(prisma.rule.findUnique).mockResolvedValue({
+ id: "rule-id",
+ name: "Test Rule",
+ groupId: null,
+ } as any);
+ vi.mocked(prisma.group.create).mockResolvedValue({
+ id: newGroupId,
+ } as any);
+ vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any);
+
+ await saveLearnedPattern({
+ emailAccountId: "email-account-id",
+ from: "test@example.com",
+ ruleId: "rule-id",
+ logger: mockLogger,
+ });
+
+ expect(prisma.group.create).toHaveBeenCalledWith({
+ data: {
+ emailAccountId: "email-account-id",
+ name: "Test Rule",
+ rule: { connect: { id: "rule-id" } },
+ },
+ });
+ expect(prisma.groupItem.upsert).toHaveBeenCalledWith(
+ expect.objectContaining({
+ where: {
+ groupId_type_value: {
+ groupId: newGroupId,
+ type: GroupItemType.FROM,
+ value: "test@example.com",
+ },
+ },
+ }),
+ );
+ });
+
+ it("should save pattern with exclude: true", async () => {
+ vi.mocked(prisma.rule.findUnique).mockResolvedValue({
+ id: "rule-id",
+ name: "Test Rule",
+ groupId: "group-id",
+ } as any);
+ vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any);
+
+ await saveLearnedPattern({
+ emailAccountId: "email-account-id",
+ from: "excluded@example.com",
+ ruleId: "rule-id",
+ exclude: true,
+ logger: mockLogger,
+ reason: "User excluded",
+ source: GroupItemSource.USER,
+ });
+
+ expect(prisma.groupItem.upsert).toHaveBeenCalledWith({
+ where: {
+ groupId_type_value: {
+ groupId: "group-id",
+ type: GroupItemType.FROM,
+ value: "excluded@example.com",
+ },
+ },
+ update: {
+ exclude: true,
+ reason: "User excluded",
+ threadId: undefined,
+ messageId: undefined,
+ source: GroupItemSource.USER,
+ },
+ create: {
+ groupId: "group-id",
+ type: GroupItemType.FROM,
+ value: "excluded@example.com",
+ exclude: true,
+ reason: "User excluded",
+ threadId: undefined,
+ messageId: undefined,
+ source: GroupItemSource.USER,
+ },
+ });
+ });
+
+ it("should handle duplicate group creation by finding existing group", async () => {
+ const existingGroupId = "existing-group-id";
+ vi.mocked(prisma.rule.findUnique)
+ .mockResolvedValueOnce({
+ id: "rule-id",
+ name: "Test Rule",
+ groupId: null,
+ } as any)
+ .mockResolvedValueOnce({
+ groupId: null,
+ } as any);
+
+ const duplicateError = new Error("Duplicate key");
+ vi.mocked(prisma.group.create).mockRejectedValue(duplicateError);
+ vi.mocked(isDuplicateError).mockReturnValue(true);
+ vi.mocked(prisma.group.findUnique).mockResolvedValue({
+ id: existingGroupId,
+ } as any);
+ vi.mocked(prisma.rule.update).mockResolvedValue({} as any);
+ vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any);
+
+ await saveLearnedPattern({
+ emailAccountId: "email-account-id",
+ from: "test@example.com",
+ ruleId: "rule-id",
+ logger: mockLogger,
+ });
+
+ expect(prisma.group.findUnique).toHaveBeenCalledWith({
+ where: {
+ name_emailAccountId: {
+ name: "Test Rule",
+ emailAccountId: "email-account-id",
+ },
+ },
+ select: { id: true },
+ });
+ expect(prisma.groupItem.upsert).toHaveBeenCalledWith(
+ expect.objectContaining({
+ where: {
+ groupId_type_value: {
+ groupId: existingGroupId,
+ type: GroupItemType.FROM,
+ value: "test@example.com",
+ },
+ },
+ }),
+ );
+ });
+});
+
+describe("saveLearnedPatterns", () => {
+ beforeEach(() => {
+ vi.clearAllMocks();
+ });
+
+ it("should return error if rule not found", async () => {
+ vi.mocked(prisma.rule.findUnique).mockResolvedValue(null);
+
+ const result = await saveLearnedPatterns({
+ emailAccountId: "email-account-id",
+ ruleName: "Nonexistent Rule",
+ patterns: [{ type: GroupItemType.FROM, value: "test@example.com" }],
+ logger: mockLogger,
+ });
+
+ expect(result).toEqual({ error: "Rule not found" });
+ expect(mockLogger.error).toHaveBeenCalledWith("Rule not found", {
+ emailAccountId: "email-account-id",
+ ruleName: "Nonexistent Rule",
+ });
+ });
+
+ it("should save multiple patterns successfully", async () => {
+ vi.mocked(prisma.rule.findUnique).mockResolvedValue({
+ id: "rule-id",
+ groupId: "group-id",
+ } as any);
+ vi.mocked(prisma.groupItem.upsert).mockResolvedValue({} as any);
+
+ const result = await saveLearnedPatterns({
+ emailAccountId: "email-account-id",
+ ruleName: "Test Rule",
+ patterns: [
+ { type: GroupItemType.FROM, value: "sender1@example.com" },
+ { type: GroupItemType.SUBJECT, value: "Newsletter", exclude: true },
+ ],
+ logger: mockLogger,
+ });
+
+ expect(result).toEqual({ success: true });
+ expect(prisma.groupItem.upsert).toHaveBeenCalledTimes(2);
+ });
+});
diff --git a/apps/web/utils/rule/learned-patterns.ts b/apps/web/utils/rule/learned-patterns.ts
index 0517c21164..7742aa2b8e 100644
--- a/apps/web/utils/rule/learned-patterns.ts
+++ b/apps/web/utils/rule/learned-patterns.ts
@@ -1,6 +1,6 @@
import prisma from "@/utils/prisma";
import type { Logger } from "@/utils/logger";
-import { GroupItemType } from "@/generated/prisma/enums";
+import { GroupItemType, type GroupItemSource } from "@/generated/prisma/enums";
import { isDuplicateError } from "@/utils/prisma-helpers";
/**
@@ -11,43 +11,41 @@ import { isDuplicateError } from "@/utils/prisma-helpers";
export async function saveLearnedPattern({
emailAccountId,
from,
- ruleName,
+ ruleId,
+ exclude = false,
logger,
+ reason,
+ threadId,
+ messageId,
+ source,
}: {
emailAccountId: string;
from: string;
- ruleName: string;
+ ruleId: string;
+ exclude?: boolean;
logger: Logger;
+ reason?: string | null;
+ threadId?: string | null;
+ messageId?: string | null;
+ source?: GroupItemSource | null;
}) {
const rule = await prisma.rule.findUnique({
- where: {
- name_emailAccountId: {
- name: ruleName,
- emailAccountId,
- },
- },
- select: { id: true, groupId: true },
+ where: { id: ruleId, emailAccountId },
+ select: { id: true, name: true, groupId: true },
});
if (!rule) {
- logger.error("Rule not found", { emailAccountId, ruleName });
+ logger.error("Rule not found", { ruleId });
return;
}
- let groupId = rule.groupId;
-
- if (!groupId) {
- // Create a new group for this rule if one doesn't exist
- const newGroup = await prisma.group.create({
- data: {
- emailAccountId,
- name: ruleName,
- rule: { connect: { id: rule.id } },
- },
- });
-
- groupId = newGroup.id;
- }
+ const groupId = await getOrCreateGroupForRule({
+ emailAccountId,
+ ruleId: rule.id,
+ ruleName: rule.name,
+ existingGroupId: rule.groupId,
+ logger,
+ });
await prisma.groupItem.upsert({
where: {
@@ -57,11 +55,22 @@ export async function saveLearnedPattern({
value: from,
},
},
- update: {},
+ update: {
+ exclude,
+ reason,
+ threadId,
+ messageId,
+ source,
+ },
create: {
groupId,
type: GroupItemType.FROM,
value: from,
+ exclude,
+ reason,
+ threadId,
+ messageId,
+ source,
},
});
}
@@ -100,43 +109,24 @@ export async function saveLearnedPatterns({
return { error: "Rule not found" };
}
- let groupId = rule.groupId;
-
- if (!groupId) {
- try {
- const newGroup = await prisma.group.create({
- data: {
- emailAccountId,
- name: ruleName,
- rule: { connect: { id: rule.id } },
- },
- });
-
- groupId = newGroup.id;
- } catch (error) {
- if (isDuplicateError(error)) {
- logger.error("Group already exists", { emailAccountId, ruleName });
- const newGroup2 = await prisma.group.create({
- data: {
- emailAccountId,
- name: `${ruleName} (${new Date().toISOString()})`,
- rule: { connect: { id: rule.id } },
- },
- });
- groupId = newGroup2.id;
- } else {
- logger.error("Error creating learned patterns group", { error });
- return { error: "Error creating learned patterns group" };
- }
- }
+ let groupId: string;
+ try {
+ groupId = await getOrCreateGroupForRule({
+ emailAccountId,
+ ruleId: rule.id,
+ ruleName: ruleName,
+ existingGroupId: rule.groupId,
+ logger,
+ });
+ } catch (error) {
+ logger.error("Error creating learned patterns group", { error });
+ return { error: "Error creating learned patterns group" };
}
const errors: string[] = [];
// Process all patterns in a single function
for (const pattern of patterns) {
- // Store pattern with the exclude flag properly set in the database
- // This maps directly to the new exclude field in the GroupItem model
try {
await prisma.groupItem.upsert({
where: {
@@ -175,3 +165,65 @@ export async function saveLearnedPatterns({
return { success: true };
}
+
+async function getOrCreateGroupForRule({
+ emailAccountId,
+ ruleId,
+ ruleName,
+ existingGroupId,
+ logger,
+}: {
+ emailAccountId: string;
+ ruleId: string;
+ ruleName: string;
+ existingGroupId: string | null;
+ logger: Logger;
+}): Promise {
+ if (existingGroupId) return existingGroupId;
+
+ // Try to create the group
+ try {
+ const newGroup = await prisma.group.create({
+ data: {
+ emailAccountId,
+ name: ruleName,
+ rule: { connect: { id: ruleId } },
+ },
+ });
+ return newGroup.id;
+ } catch (error) {
+ if (!isDuplicateError(error)) throw error;
+ }
+
+ // Handle duplicate: check if rule was concurrently updated with a group
+ const updatedRule = await prisma.rule.findUnique({
+ where: { id: ruleId },
+ select: { groupId: true },
+ });
+ if (updatedRule?.groupId) return updatedRule.groupId;
+
+ // Check if a group with the same name exists
+ const existingGroup = await prisma.group.findUnique({
+ where: { name_emailAccountId: { name: ruleName, emailAccountId } },
+ select: { id: true },
+ });
+
+ if (existingGroup) {
+ // Attempt to link it (ignore failures from concurrent updates)
+ await prisma.rule
+ .update({ where: { id: ruleId }, data: { groupId: existingGroup.id } })
+ .catch((error) => {
+ logger.warn(
+ "Failed to link existing group to rule (likely concurrent update)",
+ {
+ ruleId,
+ groupId: existingGroup.id,
+ error,
+ },
+ );
+ });
+ return existingGroup.id;
+ }
+
+ throw new Error(`Failed to create or find group for rule: ${ruleName}`);
+}