Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/web/app/(app)/[emailAccountId]/simple/Summary.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { useCompletion } from "ai/react";
import { useCompletion } from "@ai-sdk/react";
import { useEffect } from "react";
import { ButtonLoader } from "@/components/Loading";
import { ViewMoreButton } from "@/app/(app)/[emailAccountId]/simple/ViewMoreButton";
Expand Down
112 changes: 28 additions & 84 deletions apps/web/app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { appendClientMessage, appendResponseMessages } from "ai";
import { convertToModelMessages, type UIMessage } from "ai";
import { z } from "zod";
import { withEmailAccount } from "@/utils/middleware";
import { getEmailAccountWithAi } from "@/utils/user/get";
import { NextResponse } from "next/server";
import { aiProcessAssistantChat } from "@/utils/ai/assistant/chat";
import { createScopedLogger } from "@/utils/logger";
import prisma from "@/utils/prisma";
import { Prisma, type ChatMessage } from "@prisma/client";
import type { Prisma } from "@prisma/client";
import { convertToUIMessages } from "@/components/assistant-chat/helpers";
import { captureException } from "@/utils/error";

export const maxDuration = 120;

Expand All @@ -21,19 +23,8 @@ const assistantInputSchema = z.object({
id: z.string(),
message: z.object({
id: z.string(),
createdAt: z.coerce.date(),
role: z.enum(["user"]),
content: z.string().min(1).max(3000),
parts: z.array(textPartSchema),
// experimental_attachments: z
// .array(
// z.object({
// url: z.string().url(),
// name: z.string().min(1).max(100),
// contentType: z.enum(["image/png", "image/jpg", "image/jpeg"]),
// }),
// )
// .optional(),
}),
});

Expand All @@ -49,7 +40,6 @@ export const POST = withEmailAccount(async (request) => {

if (error) return NextResponse.json({ error: error.errors }, { status: 400 });

// create chat if it doesn't exist
const chat =
(await getChatById(data.id)) ||
(await createNewChat({ emailAccountId, chatId: data.id }));
Expand All @@ -69,63 +59,27 @@ export const POST = withEmailAccount(async (request) => {
}

const { message } = data;
const mappedDbMessages = chat.messages.map((dbMsg: ChatMessage) => {
return {
...dbMsg,
role: convertDbRoleToSdkRole(dbMsg.role),
content: "",
parts: dbMsg.parts as any,
};
});

const messages = appendClientMessage({
messages: mappedDbMessages,
message,
});
const uiMessages = [...convertToUIMessages(chat), message];

await saveChatMessage({
chat: { connect: { id: chat.id } },
id: message.id,
role: "user",
parts: message.parts,
// attachments: message.experimental_attachments ?? [],
});

try {
const result = await aiProcessAssistantChat({
messages,
messages: convertToModelMessages(uiMessages),
emailAccountId,
user,
onFinish: async ({ response }) => {
const assistantMessages = response.messages.filter(
(message) => message.role === "assistant",
);
const assistantId = getTrailingMessageId(assistantMessages);

if (!assistantId) {
logger.error("No assistant message found!", { response });
throw new Error("No assistant message found!");
}

// handles all tool calls
const [, assistantMessage] = appendResponseMessages({
messages: [message],
responseMessages: response.messages,
});

await saveChatMessage({
id: assistantId,
chat: { connect: { id: chat.id } },
role: assistantMessage.role,
parts: assistantMessage.parts
? (assistantMessage.parts as Prisma.InputJsonValue)
: Prisma.JsonNull,
// attachments: assistantMessage.experimental_attachments ?? [],
});
},
});

return result.toDataStreamResponse();
return result.toUIMessageStreamResponse({
onFinish: async ({ messages }) => {
await saveChatMessages(messages, chat.id);
},
});
} catch (error) {
logger.error("Error in assistant chat", { error });
return NextResponse.json(
Expand Down Expand Up @@ -155,10 +109,6 @@ async function createNewChat({
}
}

async function saveChatMessage(message: Prisma.ChatMessageCreateInput) {
return prisma.chatMessage.create({ data: message });
}

async function getChatById(chatId: string) {
const chat = await prisma.chat.findUnique({
where: { id: chatId },
Expand All @@ -167,29 +117,23 @@ async function getChatById(chatId: string) {
return chat;
}

function convertDbRoleToSdkRole(
role: string,
): "user" | "assistant" | "system" | "data" {
switch (role) {
case "user":
return "user";
case "assistant":
return "assistant";
case "system":
return "system";
case "data":
return "data";
default:
return "assistant";
}
async function saveChatMessage(message: Prisma.ChatMessageCreateInput) {
return prisma.chatMessage.create({ data: message });
}

function getTrailingMessageId<T extends { id: string }>(
messages: Array<T>,
): string | null {
const trailingMessage = messages.at(-1);

if (!trailingMessage) return null;

return trailingMessage.id;
async function saveChatMessages(messages: UIMessage[], chatId: string) {
try {
return prisma.chatMessage.createMany({
data: messages.map((message) => ({
id: message.id,
chatId,
role: message.role,
parts: message.parts as Prisma.InputJsonValue,
})),
});
} catch (error) {
logger.error("Failed to save chat messages", { error, chatId });
captureException(error, { extra: { chatId } });
throw error;
}
}
4 changes: 1 addition & 3 deletions apps/web/components/assistant-chat/ChatContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ export function ChatProvider({
export function useChat(): ChatContextType {
const context = useContext(ChatContext);
if (context === undefined) {
// TODO: throw error once this feature is live
// throw new Error("useChat must be used within a ChatProvider");
return { setInput: null };
throw new Error("useChat must be used within a ChatProvider");
}
return context;
}
Loading
Loading