Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions apps/web/app/(app)/[emailAccountId]/assistant/RuleForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import { useLabels } from "@/hooks/useLabels";
import { createLabelAction } from "@/utils/actions/mail";
import { MultiSelectFilter } from "@/components/MultiSelectFilter";
import { useCategories } from "@/hooks/useCategories";
import { hasVariables } from "@/utils/template";
import { hasVariables, TEMPLATE_VARIABLE_PATTERN } from "@/utils/template";
import { getEmptyCondition } from "@/utils/condition";
import { AlertError } from "@/components/Alert";
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
Expand Down Expand Up @@ -1291,7 +1291,7 @@ function ActionCard({
name={`actions.${index}.${field.name}.ai`}
labelRight="AI generated"
enabled={isAiGenerated || false}
onChange={(enabled: boolean) => {
onChange={(enabled) => {
setValue(
`actions.${index}.${field.name}`,
enabled
Expand All @@ -1312,7 +1312,9 @@ function ActionCard({
canFieldUseVariables(field, isAiGenerated) && (
<div className="mt-2 whitespace-pre-wrap rounded-md bg-muted/50 p-2 font-mono text-sm text-foreground">
{(value || "")
.split(/(\{\{.*?\}\})/g)
.split(
new RegExp(`(${TEMPLATE_VARIABLE_PATTERN})`, "g"),
)
.map((part: string, idx: number) =>
part.startsWith("{{") ? (
<span
Expand Down
5 changes: 4 additions & 1 deletion apps/web/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ export const env = createEnv({
NEXT_PUBLIC_LOG_SCOPES: z
.string()
.optional()
.transform((value) => value?.split(",")),
.transform((value) => {
if (!value) return;
return value.split(",");
}),
NEXT_PUBLIC_BEDROCK_SONNET_MODEL: z
.string()
.default("us.anthropic.claude-3-7-sonnet-20250219-v1:0"),
Expand Down
80 changes: 57 additions & 23 deletions apps/web/utils/ai/choose-rule/ai-choose-args.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { z } from "zod";
import { InvalidArgumentError } from "ai";
import { createGenerateText, withRetry } from "@/utils/llms";
import { createGenerateObject, withRetry } from "@/utils/llms";
import { stringifyEmail } from "@/utils/stringify-email";
import { createScopedLogger } from "@/utils/logger";
import type { EmailAccountWithAI } from "@/utils/llms/types";
import type { EmailForLLM, RuleWithActions } from "@/utils/types";
import type { ActionType } from "@prisma/client";
import { LogicalOperator, type ActionType } from "@prisma/client";
import { getModel, type ModelType } from "@/utils/llms/model";

/**
Expand Down Expand Up @@ -35,6 +35,14 @@ import { getModel, type ModelType } from "@/utils/llms/model";

const logger = createScopedLogger("AI Choose Args");

export type ActionArgResponse = {
[key: `${string}-${string}`]: {
[field: string]: {
[key: `var${number}`]: string;
};
};
};

export async function aiGenerateArgs({
email,
emailAccount,
Expand All @@ -53,7 +61,7 @@ export async function aiGenerateArgs({
>;
}[];
modelType: ModelType;
}) {
}): Promise<ActionArgResponse | undefined> {
const loggerOptions = {
email: emailAccount.email,
ruleId: selectedRule.id,
Expand All @@ -75,31 +83,24 @@ export async function aiGenerateArgs({

const modelOptions = getModel(emailAccount.user, modelType);

const generateText = createGenerateText({
const generateObject = createGenerateObject({
label: "Args for rule",
userEmail: emailAccount.email,
modelOptions,
});

const aiResponse = await withRetry(
() =>
generateText({
generateObject({
...modelOptions,
system,
prompt,
tools: {
apply_rule: {
description: "Apply the rule with the given arguments.",
inputSchema: z.object(
Object.fromEntries(
parameters.map((p) => [
`${p.type}-${p.actionId}`,
p.parameters,
]),
),
),
},
},
schemaDescription: "The arguments for the rule",
schema: z.object(
Object.fromEntries(
parameters.map((p) => [`${p.type}-${p.actionId}`, p.parameters]),
),
),
}),
{
retryIf: (error: unknown) => InvalidArgumentError.isInstance(error),
Expand All @@ -108,18 +109,16 @@ export async function aiGenerateArgs({
},
);

const toolCall = aiResponse.toolCalls?.[0];
const result = aiResponse.object;

if (!toolCall?.input) {
if (!result) {
logger.warn("No tool call found", {
...loggerOptions,
aiResponse,
});
return;
}

const result = toolCall.input;

return result;
}

Expand Down Expand Up @@ -155,10 +154,45 @@ function getPrompt({
return `Process this email according to the selected rule:

<selected_rule>
${selectedRule.instructions}
${printConditions(selectedRule)}
</selected_rule>

<email>
${stringifyEmail(email, 3000)}
</email>`;
}

function printConditions(condition: RuleWithActions) {
const result: string[] = [];
if (condition.instructions) {
result.push(`<match>${condition.instructions}</match>`);
}

const staticConditions = printStaticConditions(condition);
if (staticConditions) {
result.push(`<match>${staticConditions}</match>`);
}

return result.join(
condition.conditionalOperator === LogicalOperator.AND
? "\nAND\n"
: "\nOR\n",
);
}

function printStaticConditions(condition: RuleWithActions) {
const result: string[] = [];
if (condition.from) {
result.push(`From: ${condition.from}`);
}
if (condition.to) {
result.push(`To: ${condition.to}`);
}
if (condition.subject) {
result.push(`Subject: ${condition.subject}`);
}
if (condition.body) {
result.push(`Body: ${condition.body}`);
}
return result.join("\n");
}
19 changes: 5 additions & 14 deletions apps/web/utils/ai/choose-rule/choose-args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,15 @@ import {
} from "@/utils/types";
import { fetchMessagesAndGenerateDraft } from "@/utils/reply-tracker/generate-draft";
import { getEmailForLLM } from "@/utils/get-email-from-message";
import { aiGenerateArgs } from "@/utils/ai/choose-rule/ai-choose-args";
import {
type ActionArgResponse,
aiGenerateArgs,
} from "@/utils/ai/choose-rule/ai-choose-args";
import { createScopedLogger } from "@/utils/logger";
import type { EmailProvider } from "@/utils/email/types";

const logger = createScopedLogger("choose-args");

type ActionArgResponse = {
[key: `${string}-${string}`]: {
[field: string]: {
[key: `var${number}`]: string;
};
};
};

export async function getActionItemsWithAiArgs({
message,
emailAccount,
Expand Down Expand Up @@ -84,11 +79,7 @@ export async function getActionItemsWithAiArgs({
modelType,
});

return combineActionsWithAiArgs(
selectedRule.actions,
result as ActionArgResponse,
draft,
);
return combineActionsWithAiArgs(selectedRule.actions, result, draft);
}

export function combineActionsWithAiArgs(
Expand Down
136 changes: 135 additions & 1 deletion apps/web/utils/risk.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { describe, it, expect, vi } from "vitest";
import { getRiskLevel, getActionRiskLevel } from "./risk";
import {
getRiskLevel,
getActionRiskLevel,
isFullyDynamicField,
isPartiallyDynamicField,
} from "./risk";
import { ActionType } from "@prisma/client";
import type { RulesResponse } from "@/app/api/user/rules/route";

Expand Down Expand Up @@ -217,3 +222,132 @@ describe("getRiskLevel", () => {
},
);
});

describe("isFullyDynamicField", () => {
const testCases = [
{
name: "returns true for single-line template variable",
field: "{{name}}",
expected: true,
},
{
name: "returns true for multi-line template variable",
field: `{{
tell a funny joke.
do it in the language of the questioner.
always start with "Here's a great joke:"
}}`,
expected: true,
},
{
name: "returns true for template variable with spaces",
field: "{{ write a greeting }}",
expected: true,
},
{
name: "returns false for partially dynamic field",
field: "Hello {{name}}",
expected: false,
},
{
name: "returns false for static field",
field: "Static content",
expected: false,
},
{
name: "returns false for empty string",
field: "",
expected: false,
},
{
name: "returns true for field with multiple template variables (starts and ends with braces)",
field: "{{greeting}} {{name}}",
expected: true,
},
{
name: "returns true for complex multi-line template",
field: `{{
Generate a personalized response that:
1. Acknowledges their request
2. Provides helpful information
3. Maintains a professional tone
}}`,
expected: true,
},
];

testCases.forEach(({ name, field, expected }) => {
it(name, () => {
expect(isFullyDynamicField(field)).toBe(expected);
});
});
});

describe("isPartiallyDynamicField", () => {
const testCases = [
{
name: "returns true for single-line template variable",
field: "{{name}}",
expected: true,
},
{
name: "returns true for multi-line template variable",
field: `{{
tell a funny joke.
do it in the language of the questioner.
always start with "Here's a great joke:"
}}`,
expected: true,
},
{
name: "returns true for partially dynamic field",
field: "Hello {{name}}",
expected: true,
},
{
name: "returns true for field with multiple template variables",
field: "{{greeting}} {{name}}",
expected: true,
},
{
name: "returns true for mixed content with multi-line template",
field: `Hi {{name}}!

{{
Please write a personalized response based on:
- Their previous interactions
- Their current needs
- Our company policies
}}

Best regards`,
expected: true,
},
{
name: "returns false for static field",
field: "Static content",
expected: false,
},
{
name: "returns false for empty string",
field: "",
expected: false,
},
{
name: "returns false for field with only curly braces (no double)",
field: "Hello {name}",
expected: false,
},
{
name: "returns false for field with malformed template syntax",
field: "Hello {{name}",
expected: false,
},
];

testCases.forEach(({ name, field, expected }) => {
it(name, () => {
expect(isPartiallyDynamicField(field)).toBe(expected);
});
});
});
10 changes: 6 additions & 4 deletions apps/web/utils/risk.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { RulesResponse } from "@/app/api/user/rules/route";
import { isAIRule, type RuleConditions } from "@/utils/condition";
import { ActionType } from "@prisma/client";
import { TEMPLATE_VARIABLE_PATTERN } from "@/utils/template";

const RISK_LEVELS = {
VERY_HIGH: "very-high",
Expand Down Expand Up @@ -164,10 +165,11 @@ function getFieldsDynamicStatus(action: RiskAction) {
}

// Helper functions
function isFullyDynamicField(field: string) {
return /^\{\{.*?\}\}$/.test(field);
export function isFullyDynamicField(field: string) {
const trimmed = field.trim();
return trimmed.startsWith("{{") && trimmed.endsWith("}}");
}

function isPartiallyDynamicField(field: string) {
return /\{\{.*?\}\}/.test(field);
export function isPartiallyDynamicField(field: string) {
return new RegExp(TEMPLATE_VARIABLE_PATTERN).test(field);
}
Loading
Loading